Computação quadrada bignum rápida
Para acelerar minhas divisões bignum, preciso acelerar a operaçãoy = x^2
para bigints representados como matrizes dinâmicas de DWORDs não assinadas. Para ser claro:
DWORD x[n+1] = { LSW, ......, MSW };
onde n + 1 é o número de DWORDs usadosentão valor do númerox = x[0]+x[1]<<32 + ... x[N]<<32*(n)
A questão é:Como faço para calculary = x^2
o mais rápido possível sem perda de precisão? - UsandoC ++ e com aritmética inteira (32 bits com Carry) à disposição.
Minha abordagem atual é aplicar multiplicaçãoy = x*x
e evite múltiplas multiplicações.
Por exemplo:
x = x[0] + x[1]<<32 + ... x[n]<<32*(n)
Para simplificar, deixe-me reescrevê-lo:
x = x0+ x1 + x2 + ... + xn
onde index representa o endereço dentro da matriz, então:
y = x*x
y = (x0 + x1 + x2 + ...xn)*(x0 + x1 + x2 + ...xn)
y = x0*(x0 + x1 + x2 + ...xn) + x1*(x0 + x1 + x2 + ...xn) + x2*(x0 + x1 + x2 + ...xn) + ...xn*(x0 + x1 + x2 + ...xn)
y0 = x0*x0
y1 = x1*x0 + x0*x1
y2 = x2*x0 + x1*x1 + x0*x2
y3 = x3*x0 + x2*x1 + x1*x2
...
y(2n-3) = xn(n-2)*x(n ) + x(n-1)*x(n-1) + x(n )*x(n-2)
y(2n-2) = xn(n-1)*x(n ) + x(n )*x(n-1)
y(2n-1) = xn(n )*x(n )
Após uma análise mais detalhada, fica claro que quase todosxi*xj
aparece duas vezes (não o primeiro e o último), o que significa queN*N
multiplicações podem ser substituídas por(N+1)*(N/2)
multiplicações. P.S.32bit*32bit = 64bit
então o resultado de cadamul+add
operação é tratada como64+1 bit
.
Existe uma maneira melhor de calcular isso rápido? Tudo o que encontrei durante as pesquisas foram algoritmos sqrts, não sqr ...
Sqr rápido
!!! Lembre-se de que todos os números no meu código são MSW primeiro, ... não como no teste acima (existem LSW primeiro pela simplicidade das equações, caso contrário, seria uma bagunça de índice).
Implementação funcional atual do fsqr
void arbnum::sqr(const arbnum &x)
{
// O((N+1)*N/2)
arbnum c;
DWORD h, l;
int N, nx, nc, i, i0, i1, k;
c._alloc(x.siz + x.siz + 1);
nx = x.siz - 1;
nc = c.siz - 1;
N = nx + nx;
for (i=0; i<=nc; i++)
c.dat[i]=0;
for (i=1; i<N; i++)
for (i0=0; (i0<=nx) && (i0<=i); i0++)
{
i1 = i - i0;
if (i0 >= i1)
break;
if (i1 > nx)
continue;
h = x.dat[nx-i0];
if (!h)
continue;
l = x.dat[nx-i1];
if (!l)
continue;
alu.mul(h, l, h, l);
k = nc - i;
if (k >= 0)
alu.add(c.dat[k], c.dat[k], l);
k--;
if (k>=0)
alu.adc(c.dat[k], c.dat[k],h);
k--;
for (; (alu.cy) && (k>=0); k--)
alu.inc(c.dat[k]);
}
c.shl(1);
for (i = 0; i <= N; i += 2)
{
i0 = i>>1;
h = x.dat[nx-i0];
if (!h)
continue;
alu.mul(h, l, h, h);
k = nc - i;
if (k >= 0)
alu.add(c.dat[k], c.dat[k],l);
k--;
if (k>=0)
alu.adc(c.dat[k], c.dat[k], h);
k--;
for (; (alu.cy) && (k >= 0); k--)
alu.inc(c.dat[k]);
}
c.bits = c.siz<<5;
c.exp = x.exp + x.exp + ((c.siz - x.siz - x.siz)<<5) + 1;
c.sig = sig;
*this = c;
}
Uso da multiplicação de Karatsuba
(graças a Calpis)
Eu implementei a multiplicação de Karatsuba, mas os resultados são massivamente mais lentos do que pelo uso de simplesO(N^2)
multiplicação, provavelmente por causa daquela horrível recursão que não vejo como evitar. O trade-off deve estar em números realmente grandes (maiores que centenas de dígitos) ... mas mesmo assim há muitas transferências de memória. Existe uma maneira de evitar chamadas de recursão (variante não recursiva, ... Quase todos os algoritmos recursivos podem ser feitos dessa maneira). Ainda assim, tentarei ajustar as coisas e ver o que acontece (evitar normalizações, etc ..., também pode haver algum erro bobo no código). De qualquer forma, depois de resolver o Karatsuba para o casox*x
não há muito ganho de desempenho.
Multiplicação otimizada de Karatsuba
Teste de desempenho paray = x^2 looped 1000x times, 0.9 < x < 1 ~ 32*98 bits
:
x = 0.98765588997654321000000009876... | 98*32 bits
sqr [ 213.989 ms ] ... O((N+1)*N/2) fast sqr
mul1[ 363.472 ms ] ... O(N^2) classic multiplication
mul2[ 349.384 ms ] ... O(3*(N^log2(3))) optimized Karatsuba multiplication
mul3[ 9345.127 ms] ... O(3*(N^log2(3))) unoptimized Karatsuba multiplication
x = 0.98765588997654321000... | 195*32 bits
sqr [ 883.01 ms ]
mul1[ 1427.02 ms ]
mul2[ 1089.84 ms ]
x = 0.98765588997654321000... | 389*32 bits
sqr [ 3189.19 ms ]
mul1[ 5553.23 ms ]
mul2[ 3159.07 ms ]
Após otimizações para o Karatsuba, o código é muito mais rápido do que antes. Ainda assim, para números menores, é um pouco menos da metade da velocidade do meuO(N^2)
multiplicação. Para números maiores, é mais rápido com a proporção dada pelas complexidades das multiplicações de Booth. O limite para multiplicação é de cerca de 32 * 98 bits e para o sqr, de 32 * 389 bits; portanto, se a soma dos bits de entrada ultrapassar esse limite, a multiplicação do Karatsuba será usada para acelerar a multiplicação e isso também será semelhante ao sqr.
BTW, as otimizações incluíram:
Minimize a lixeira do heap pelo argumento de recursão muito grandePara evitar qualquer aritmética bignum (+, -) a ALU de 32 bits com carry é usada.Ignorando0*y
oux*0
ou0*0
casosReformatação de entradax,y
tamanhos de número com potência de dois para evitar a realocaçãoImplementar a multiplicação do módulo paraz1 = (x0 + x1)*(y0 + y1)
para minimizar a recursãoMultiplicação de Schönhage-Strassen modificada para implementação de sqr
Eu testei o uso deFFT eNTT transforma para acelerar a computação sqr. Os resultados são os seguintes:
FFT
Perde a precisão e, portanto, precisa de números complexos de alta precisão. Isso realmente torna as coisas consideravelmente mais lentas, para que não haja aceleração. O resultado não é preciso (pode ser arredondado incorretamente), portantoFFT está inutilizável (por enquanto)
NTT
NTT é um campo finitoDFT e, portanto, nenhuma perda de precisão ocorre. Precisa de aritmética modular em números inteiros não assinados:modpow, modmul, modadd
emodsub
.
eu usoDWORD
(Números inteiros não assinados de 32 bits). oNTT O tamanho do vetor de entrada / otput é limitado devido a problemas de estouro !!! Para aritmética modular de 32 bits,N
é limitado a(2^32)/(max(input[])^2)
tãobigint
deve ser dividido em pedaços menores (eu usoBYTES
tamanho máximo debigint
processado é
(2^32)/((2^8)^2) = 2^16 bytes = 2^14 DWORDs = 16384 DWORDs)
osqr
usa apenas1xNTT + 1xINTT
ao invés de2xNTT + 1xINTT
para multiplicação, masNTT o uso é muito lento e o tamanho do número limite é muito grande para uso prático em minha implementação (pormul
e também parasqr
)
É possível que esteja acima do limite de estouro, portanto, aritméticas modulares de 64 bits devem ser usadas, o que pode atrasar ainda mais as coisas. assimNTT também é inutilizável para os meus propósitos.
Algumas medidas:
a = 0.98765588997654321000 | 389*32 bits
looped 1x times
sqr1[ 3.177 ms ] fast sqr
sqr2[ 720.419 ms ] NTT sqr
mul1[ 5.588 ms ] simpe mul
mul2[ 3.172 ms ] karatsuba mul
mul3[ 1053.382 ms ] NTT mul
Minha implementação:
void arbnum::sqr_NTT(const arbnum &x)
{
// O(N*log(N)*(log(log(N)))) - 1x NTT
// Schönhage-Strassen sqr
// To prevent NTT overflow: n <= 48K * 8 bit -> result siz <= 12K * 32 bit -> x.siz + y.siz <= 12K!!!
int i, j, k, n;
int s = x.sig*x.sig, exp0 = x.exp + x.exp - ((x.siz+x.siz)<<5) + 2;
i = x.siz;
for (n = 1; n < i; n<<=1)
;
if (n + n > 0x3000) {
_error(_arbnum_error_TooBigNumber);
zero();
return;
}
n <<= 3;
DWORD *xx, *yy, q, qq;
xx = new DWORD[n+n];
#ifdef _mmap_h
if (xx)
mmap_new(xx, (n+n) << 2);
#endif
if (xx==NULL) {
_error(_arbnum_error_NotEnoughMemory);
zero();
return;
}
yy = xx + n;
// Zero padding (and split DWORDs to BYTEs)
for (i--, k=0; i >= 0; i--)
{
q = x.dat[i];
xx[k] = q&0xFF; k++; q>>=8;
xx[k] = q&0xFF; k++; q>>=8;
xx[k] = q&0xFF; k++; q>>=8;
xx[k] = q&0xFF; k++;
}
for (;k<n;k++)
xx[k] = 0;
//NTT
fourier_NTT ntt;
ntt.NTT(yy,xx,n); // init NTT for n
// Convolution
for (i=0; i<n; i++)
yy[i] = modmul(yy[i], yy[i], ntt.p);
//INTT
ntt.INTT(xx, yy);
//suma
q=0;
for (i = 0, j = 0; i<n; i++) {
qq = xx[i];
q += qq&0xFF;
yy[n-i-1] = q&0xFF;
q>>=8;
qq>>=8;
q+=qq;
}
// Merge WORDs to DWORDs and copy them to result
_alloc(n>>2);
for (i = 0, j = 0; i<siz; i++)
{
q =(yy[j]<<24)&0xFF000000; j++;
q |=(yy[j]<<16)&0x00FF0000; j++;
q |=(yy[j]<< 8)&0x0000FF00; j++;
q |=(yy[j] )&0x000000FF; j++;
dat[i] = q;
}
#ifdef _mmap_h
if (xx)
mmap_del(xx);
#endif
delete xx;
bits = siz<<5;
sig = s;
exp = exp0 + (siz<<5) - 1;
// _normalize();
}
Conclusão
Para números menores, é a melhor opção para meu rápidosqr
abordagem e após o limiarKaratsuba multiplicação é melhor. Mas ainda acho que deve haver algo trivial que negligenciamos. Alguém tem outras idéias?
Otimização NTT
Após otimizações intensamente intensas (principalmenteNTT): Pergunta de estouro de pilhaAritmética modular e otimizações NTT (DFT de campo finito).
Alguns valores foram alterados:
a = 0.98765588997654321000 | 1553*32bits
looped 10x times
mul2[ 28.585 ms ] Karatsuba mul
mul3[ 26.311 ms ] NTT mul
Então agoraNTT finalmente a multiplicação é mais rápida queKaratsuba após um limite de cerca de 1500 * 32 bits.
Algumas medições e erros detectados
a = 0.99991970486 | 1553*32 bits
looped: 10x
sqr1[ 58.656 ms ] fast sqr
sqr2[ 13.447 ms ] NTT sqr
mul1[ 102.563 ms ] simpe mul
mul2[ 28.916 ms ] Karatsuba mul Error
mul3[ 19.470 ms ] NTT mul
Eu descobri que meuKaratsuba (acima / abaixo) flui oLSB De cadaDWORD
segmento de bignum. Quando eu pesquisar, atualizarei o código ...
Além disso, depois de maisNTT otimizações, os limites foram alterados, portanto, paraNTT sqr isto é310*32 bits = 9920 bits
dooperando, e paraNTT mul isto é1396*32 bits = 44672 bits
doresultado (soma de bits de operandos).
Código Karatsuba reparado graças a @greybeard
//---------------------------------------------------------------------------
void arbnum::_mul_karatsuba(DWORD *z, DWORD *x, DWORD *y, int n)
{
// Recursion for Karatsuba
// z[2n] = x[n]*y[n];
// n=2^m
int i;
for (i=0; i<n; i++)
if (x[i]) {
i=-1;
break;
} // x==0 ?
if (i < 0)
for (i = 0; i<n; i++)
if (y[i]) {
i = -1;
break;
} // y==0 ?
if (i >= 0) {
for (i = 0; i < n + n; i++)
z[i]=0;
return;
} // 0.? = 0
if (n == 1) {
alu.mul(z[0], z[1], x[0], y[0]);
return;
}
if (n< 1)
return;
int n2 = n>>1;
_mul_karatsuba(z+n, x+n2, y+n2, n2); // z0 = x0.y0
_mul_karatsuba(z , x , y , n2); // z2 = x1.y1
DWORD *q = new DWORD[n<<1], *q0, *q1, *qq;
BYTE cx,cy;
if (q == NULL) {
_error(_arbnum_error_NotEnoughMemory);
return;
}
#define _add { alu.add(qq[i], q0[i], q1[i]); for (i--; i>=0; i--) alu.adc(qq[i], q0[i], q1[i]); } // qq = q0 + q1 ...[i..0]
#define _sub { alu.sub(qq[i], q0[i], q1[i]); for (i--; i>=0; i--) alu.sbc(qq[i], q0[i], q1[i]); } // qq = q0 - q1 ...[i..0]
qq = q;
q0 = x + n2;
q1 = x;
i = n2 - 1;
_add;
cx = alu.cy; // =x0+x1
qq = q + n2;
q0 = y + n2;
q1 = y;
i = n2 - 1;
_add;
cy = alu.cy; // =y0+y1
_mul_karatsuba(q + n, q + n2, q, n2); // =(x0+x1)(y0+y1) mod ((2^N)-1)
if (cx) {
qq = q + n;
q0 = qq;
q1 = q + n2;
i = n2 - 1;
_add;
cx = alu.cy;
}// += cx*(y0 + y1) << n2
if (cy) {
qq = q + n;
q0 = qq;
q1 = q;
i = n2 -1;
_add;
cy = alu.cy;
}// +=cy*(x0+x1)<<n2
qq = q + n; q0 = qq; q1 = z + n; i = n - 1; _sub; // -=z0
qq = q + n; q0 = qq; q1 = z; i = n - 1; _sub; // -=z2
qq = z + n2; q0 = qq; q1 = q + n; i = n - 1; _add; // z1=(x0+x1)(y0+y1)-z0-z2
DWORD ccc=0;
if (alu.cy)
ccc++; // Handle carry from last operation
if (cx || cy)
ccc++; // Handle carry from before last operation
if (ccc)
{
i = n2 - 1;
alu.add(z[i], z[i], ccc);
for (i--; i>=0; i--)
if (alu.cy)
alu.inc(z[i]);
else
break;
}
delete[] q;
#undef _add
#undef _sub
}
//---------------------------------------------------------------------------
void arbnum::mul_karatsuba(const arbnum &x, const arbnum &y)
{
// O(3*(N)^log2(3)) ~ O(3*(N^1.585))
// Karatsuba multiplication
//
int s = x.sig*y.sig;
arbnum a, b;
a = x;
b = y;
a.sig = +1;
b.sig = +1;
int i, n;
for (n = 1; (n < a.siz) || (n < b.siz); n <<= 1)
;
a._realloc(n);
b._realloc(n);
_alloc(n + n);
for (i=0; i < siz; i++)
dat[i]=0;
_mul_karatsuba(dat, a.dat, b.dat, n);
bits = siz << 5;
sig = s;
exp = a.exp + b.exp + ((siz-a.siz-b.siz)<<5) + 1;
// _normalize();
}
//---------------------------------------------------------------------------
Minhasarbnum
representação numérica:
// dat is MSDW first ... LSDW last
DWORD *dat; int siz,exp,sig,bits;
dat[siz]
é a mantisa. LSDW significa DWORD menos significativo.exp
é o expoente de MSB dedat[0]
O primeiro bit diferente de zero está presente na mantissa !!!
// |-----|---------------------------|---------------|------|
// | sig | MSB mantisa LSB | exponent | bits |
// |-----|---------------------------|---------------|------|
// | +1 | 0.(0 ... 0) | 2^0 | 0 | +zero
// | -1 | 0.(0 ... 0) | 2^0 | 0 | -zero
// |-----|---------------------------|---------------|------|
// | +1 | 1.(dat[0] ... dat[siz-1]) | 2^exp | n | +number
// | -1 | 1.(dat[0] ... dat[siz-1]) | 2^exp | n | -number
// |-----|---------------------------|---------------|------|
// | +1 | 1.0 | 2^+0x7FFFFFFE | 1 | +infinity
// | -1 | 1.0 | 2^+0x7FFFFFFE | 1 | -infinity
// |-----|---------------------------|---------------|------|