Computação quadrada bignum rápida

Para acelerar minhas divisões bignum, preciso acelerar a operaçãoy = x^2 para bigints representados como matrizes dinâmicas de DWORDs não assinadas. Para ser claro:

DWORD x[n+1] = { LSW, ......, MSW };
onde n + 1 é o número de DWORDs usadosentão valor do númerox = x[0]+x[1]<<32 + ... x[N]<<32*(n)

A questão é:Como faço para calculary = x^2 o mais rápido possível sem perda de precisão? - UsandoC ++ e com aritmética inteira (32 bits com Carry) à disposição.

Minha abordagem atual é aplicar multiplicaçãoy = x*x e evite múltiplas multiplicações.

Por exemplo:

x = x[0] + x[1]<<32 + ... x[n]<<32*(n)

Para simplificar, deixe-me reescrevê-lo:

x = x0+ x1 + x2 + ... + xn

onde index representa o endereço dentro da matriz, então:

y = x*x
y = (x0 + x1 + x2 + ...xn)*(x0 + x1 + x2 + ...xn)
y = x0*(x0 + x1 + x2 + ...xn) + x1*(x0 + x1 + x2 + ...xn) + x2*(x0 + x1 + x2 + ...xn) + ...xn*(x0 + x1 + x2 + ...xn)

y0     = x0*x0
y1     = x1*x0 + x0*x1
y2     = x2*x0 + x1*x1 + x0*x2
y3     = x3*x0 + x2*x1 + x1*x2
...
y(2n-3) = xn(n-2)*x(n  ) + x(n-1)*x(n-1) + x(n  )*x(n-2)
y(2n-2) = xn(n-1)*x(n  ) + x(n  )*x(n-1)
y(2n-1) = xn(n  )*x(n  )

Após uma análise mais detalhada, fica claro que quase todosxi*xj aparece duas vezes (não o primeiro e o último), o que significa queN*N multiplicações podem ser substituídas por(N+1)*(N/2) multiplicações. P.S.32bit*32bit = 64bit então o resultado de cadamul+add operação é tratada como64+1 bit.

Existe uma maneira melhor de calcular isso rápido? Tudo o que encontrei durante as pesquisas foram algoritmos sqrts, não sqr ...

Sqr rápido

!!! Lembre-se de que todos os números no meu código são MSW primeiro, ... não como no teste acima (existem LSW primeiro pela simplicidade das equações, caso contrário, seria uma bagunça de índice).

Implementação funcional atual do fsqr

void arbnum::sqr(const arbnum &x)
{
    // O((N+1)*N/2)
    arbnum c;
    DWORD h, l;
    int N, nx, nc, i, i0, i1, k;
    c._alloc(x.siz + x.siz + 1);
    nx = x.siz - 1;
    nc = c.siz - 1;
    N = nx + nx;
    for (i=0; i<=nc; i++)
        c.dat[i]=0;
    for (i=1; i<N; i++)
        for (i0=0; (i0<=nx) && (i0<=i); i0++)
        {
            i1 = i - i0;
            if (i0 >= i1)
                break;
            if (i1 > nx)
                continue;
            h = x.dat[nx-i0];
            if (!h)
                continue;
            l = x.dat[nx-i1];
            if (!l)
                continue;
            alu.mul(h, l, h, l);
            k = nc - i;
            if (k >= 0)
                alu.add(c.dat[k], c.dat[k], l);
            k--;
            if (k>=0)
                alu.adc(c.dat[k], c.dat[k],h);
            k--;
            for (; (alu.cy) && (k>=0); k--)
                alu.inc(c.dat[k]);
        }
        c.shl(1);
        for (i = 0; i <= N; i += 2)
        {
            i0 = i>>1;
            h = x.dat[nx-i0];
            if (!h)
                continue;
            alu.mul(h, l, h, h);
            k = nc - i;
            if (k >= 0)
                alu.add(c.dat[k], c.dat[k],l);
            k--;
            if (k>=0)
                alu.adc(c.dat[k], c.dat[k], h);
            k--;
            for (; (alu.cy) && (k >= 0); k--)
                alu.inc(c.dat[k]);
        }
        c.bits = c.siz<<5;
        c.exp = x.exp + x.exp + ((c.siz - x.siz - x.siz)<<5) + 1;
        c.sig = sig;
        *this = c;
    }

Uso da multiplicação de Karatsuba

(graças a Calpis)

Eu implementei a multiplicação de Karatsuba, mas os resultados são massivamente mais lentos do que pelo uso de simplesO(N^2) multiplicação, provavelmente por causa daquela horrível recursão que não vejo como evitar. O trade-off deve estar em números realmente grandes (maiores que centenas de dígitos) ... mas mesmo assim há muitas transferências de memória. Existe uma maneira de evitar chamadas de recursão (variante não recursiva, ... Quase todos os algoritmos recursivos podem ser feitos dessa maneira). Ainda assim, tentarei ajustar as coisas e ver o que acontece (evitar normalizações, etc ..., também pode haver algum erro bobo no código). De qualquer forma, depois de resolver o Karatsuba para o casox*x não há muito ganho de desempenho.

Multiplicação otimizada de Karatsuba

Teste de desempenho paray = x^2 looped 1000x times, 0.9 < x < 1 ~ 32*98 bits:

x = 0.98765588997654321000000009876... | 98*32 bits
sqr [ 213.989 ms ] ... O((N+1)*N/2) fast sqr
mul1[ 363.472 ms ] ... O(N^2) classic multiplication
mul2[ 349.384 ms ] ... O(3*(N^log2(3))) optimized Karatsuba multiplication
mul3[ 9345.127 ms] ... O(3*(N^log2(3))) unoptimized Karatsuba multiplication

x = 0.98765588997654321000... | 195*32 bits
sqr [ 883.01 ms ]
mul1[ 1427.02 ms ]
mul2[ 1089.84 ms ]

x = 0.98765588997654321000... | 389*32 bits
sqr [ 3189.19 ms ]
mul1[ 5553.23 ms ]
mul2[ 3159.07 ms ]

Após otimizações para o Karatsuba, o código é muito mais rápido do que antes. Ainda assim, para números menores, é um pouco menos da metade da velocidade do meuO(N^2) multiplicação. Para números maiores, é mais rápido com a proporção dada pelas complexidades das multiplicações de Booth. O limite para multiplicação é de cerca de 32 * 98 bits e para o sqr, de 32 * 389 bits; portanto, se a soma dos bits de entrada ultrapassar esse limite, a multiplicação do Karatsuba será usada para acelerar a multiplicação e isso também será semelhante ao sqr.

BTW, as otimizações incluíram:

Minimize a lixeira do heap pelo argumento de recursão muito grandePara evitar qualquer aritmética bignum (+, -) a ALU de 32 bits com carry é usada.Ignorando0*y oux*0 ou0*0 casosReformatação de entradax,y tamanhos de número com potência de dois para evitar a realocaçãoImplementar a multiplicação do módulo paraz1 = (x0 + x1)*(y0 + y1) para minimizar a recursão

Multiplicação de Schönhage-Strassen modificada para implementação de sqr

Eu testei o uso deFFT eNTT transforma para acelerar a computação sqr. Os resultados são os seguintes:

FFT

Perde a precisão e, portanto, precisa de números complexos de alta precisão. Isso realmente torna as coisas consideravelmente mais lentas, para que não haja aceleração. O resultado não é preciso (pode ser arredondado incorretamente), portantoFFT está inutilizável (por enquanto)

NTT

NTT é um campo finitoDFT e, portanto, nenhuma perda de precisão ocorre. Precisa de aritmética modular em números inteiros não assinados:modpow, modmul, modadd emodsub.

eu usoDWORD (Números inteiros não assinados de 32 bits). oNTT O tamanho do vetor de entrada / otput é limitado devido a problemas de estouro !!! Para aritmética modular de 32 bits,N é limitado a(2^32)/(max(input[])^2) tãobigint deve ser dividido em pedaços menores (eu usoBYTES tamanho máximo debigint processado é

(2^32)/((2^8)^2) = 2^16 bytes = 2^14 DWORDs = 16384 DWORDs)

osqr usa apenas1xNTT + 1xINTT ao invés de2xNTT + 1xINTT para multiplicação, masNTT o uso é muito lento e o tamanho do número limite é muito grande para uso prático em minha implementação (pormul e também parasqr)

É possível que esteja acima do limite de estouro, portanto, aritméticas modulares de 64 bits devem ser usadas, o que pode atrasar ainda mais as coisas. assimNTT também é inutilizável para os meus propósitos.

Algumas medidas:

a = 0.98765588997654321000 | 389*32 bits
looped 1x times
sqr1[ 3.177 ms ] fast sqr
sqr2[ 720.419 ms ] NTT sqr
mul1[ 5.588 ms ] simpe mul
mul2[ 3.172 ms ] karatsuba mul
mul3[ 1053.382 ms ] NTT mul

Minha implementação:

void arbnum::sqr_NTT(const arbnum &x)
{
    // O(N*log(N)*(log(log(N)))) - 1x NTT
    // Schönhage-Strassen sqr
    // To prevent NTT overflow: n <= 48K * 8 bit -> result siz <= 12K * 32 bit -> x.siz + y.siz <= 12K!!!
    int i, j, k, n;
    int s = x.sig*x.sig, exp0 = x.exp + x.exp - ((x.siz+x.siz)<<5) + 2;
    i = x.siz;
    for (n = 1; n < i; n<<=1)
        ;
    if (n + n > 0x3000) {
        _error(_arbnum_error_TooBigNumber);
        zero();
        return;
    }
    n <<= 3;
    DWORD *xx, *yy, q, qq;
    xx = new DWORD[n+n];
    #ifdef _mmap_h
    if (xx)
        mmap_new(xx, (n+n) << 2);
    #endif
    if (xx==NULL) {
        _error(_arbnum_error_NotEnoughMemory);
        zero();
        return;
    }
    yy = xx + n;

    // Zero padding (and split DWORDs to BYTEs)
    for (i--, k=0; i >= 0; i--)
    {
        q = x.dat[i];
        xx[k] = q&0xFF; k++; q>>=8;
        xx[k] = q&0xFF; k++; q>>=8;
        xx[k] = q&0xFF; k++; q>>=8;
        xx[k] = q&0xFF; k++;
    }
    for (;k<n;k++)
        xx[k] = 0;

    //NTT
    fourier_NTT ntt;

    ntt.NTT(yy,xx,n);    // init NTT for n

    // Convolution
    for (i=0; i<n; i++)
        yy[i] = modmul(yy[i], yy[i], ntt.p);

    //INTT
    ntt.INTT(xx, yy);

    //suma
    q=0;
    for (i = 0, j = 0; i<n; i++) {
        qq = xx[i];
        q += qq&0xFF;
        yy[n-i-1] = q&0xFF;
        q>>=8;
        qq>>=8;
        q+=qq;
    }

    // Merge WORDs to DWORDs and copy them to result
    _alloc(n>>2);
    for (i = 0, j = 0; i<siz; i++)
    {
        q  =(yy[j]<<24)&0xFF000000; j++;
        q |=(yy[j]<<16)&0x00FF0000; j++;
        q |=(yy[j]<< 8)&0x0000FF00; j++;
        q |=(yy[j]    )&0x000000FF; j++;
        dat[i] = q;
    }

    #ifdef _mmap_h
    if (xx)
        mmap_del(xx);
    #endif
    delete xx;
    bits = siz<<5;
    sig = s;
    exp = exp0 + (siz<<5) - 1;
        // _normalize();
    }

Conclusão

Para números menores, é a melhor opção para meu rápidosqr abordagem e após o limiarKaratsuba multiplicação é melhor. Mas ainda acho que deve haver algo trivial que negligenciamos. Alguém tem outras idéias?

Otimização NTT

Após otimizações intensamente intensas (principalmenteNTT): Pergunta de estouro de pilhaAritmética modular e otimizações NTT (DFT de campo finito).

Alguns valores foram alterados:

a = 0.98765588997654321000 | 1553*32bits
looped 10x times
mul2[ 28.585 ms ] Karatsuba mul
mul3[ 26.311 ms ] NTT mul

Então agoraNTT finalmente a multiplicação é mais rápida queKaratsuba após um limite de cerca de 1500 * 32 bits.

Algumas medições e erros detectados

a = 0.99991970486 | 1553*32 bits
looped: 10x
sqr1[  58.656 ms ] fast sqr
sqr2[  13.447 ms ] NTT sqr
mul1[ 102.563 ms ] simpe mul
mul2[  28.916 ms ] Karatsuba mul Error
mul3[  19.470 ms ] NTT mul

Eu descobri que meuKaratsuba (acima / abaixo) flui oLSB De cadaDWORD segmento de bignum. Quando eu pesquisar, atualizarei o código ...

Além disso, depois de maisNTT otimizações, os limites foram alterados, portanto, paraNTT sqr isto é310*32 bits = 9920 bits dooperando, e paraNTT mul isto é1396*32 bits = 44672 bits doresultado (soma de bits de operandos).

Código Karatsuba reparado graças a @greybeard

//---------------------------------------------------------------------------
void arbnum::_mul_karatsuba(DWORD *z, DWORD *x, DWORD *y, int n)
{
    // Recursion for Karatsuba
    // z[2n] = x[n]*y[n];
    // n=2^m
    int i;
    for (i=0; i<n; i++)
        if (x[i]) {
            i=-1;
            break;
        } // x==0 ?

    if (i < 0)
        for (i = 0; i<n; i++)
            if (y[i]) {
                i = -1;
                break;
            } // y==0 ?

    if (i >= 0) {
        for (i = 0; i < n + n; i++)
            z[i]=0;
            return;
        } // 0.? = 0

    if (n == 1) {
        alu.mul(z[0], z[1], x[0], y[0]);
        return;
    }

    if (n< 1)
        return;
    int n2 = n>>1;
    _mul_karatsuba(z+n, x+n2, y+n2, n2);                         // z0 = x0.y0
    _mul_karatsuba(z  , x   , y   , n2);                         // z2 = x1.y1
    DWORD *q = new DWORD[n<<1], *q0, *q1, *qq;
    BYTE cx,cy;
    if (q == NULL) {
        _error(_arbnum_error_NotEnoughMemory);
        return;
    }
    #define _add { alu.add(qq[i], q0[i], q1[i]); for (i--; i>=0; i--) alu.adc(qq[i], q0[i], q1[i]); } // qq = q0 + q1 ...[i..0]
    #define _sub { alu.sub(qq[i], q0[i], q1[i]); for (i--; i>=0; i--) alu.sbc(qq[i], q0[i], q1[i]); } // qq = q0 - q1 ...[i..0]
    qq = q;
    q0 = x + n2;
    q1 = x;
    i = n2 - 1;
    _add;
    cx = alu.cy; // =x0+x1

    qq = q + n2;
    q0 = y + n2;
    q1 = y;
    i = n2 - 1;
    _add;
    cy = alu.cy; // =y0+y1

    _mul_karatsuba(q + n, q + n2, q, n2);                       // =(x0+x1)(y0+y1) mod ((2^N)-1)

    if (cx) {
        qq = q + n;
        q0 = qq;
        q1 = q + n2;
        i = n2 - 1;
        _add;
        cx = alu.cy;
    }// += cx*(y0 + y1) << n2

    if (cy) {
        qq = q + n;
        q0 = qq;
        q1 = q;
        i = n2 -1;
        _add;
        cy = alu.cy;
    }// +=cy*(x0+x1)<<n2

    qq = q + n;  q0 = qq; q1 = z + n; i = n - 1; _sub;  // -=z0
    qq = q + n;  q0 = qq; q1 = z;     i = n - 1; _sub;  // -=z2
    qq = z + n2; q0 = qq; q1 = q + n; i = n - 1; _add;  // z1=(x0+x1)(y0+y1)-z0-z2

    DWORD ccc=0;

    if (alu.cy)
        ccc++;    // Handle carry from last operation
    if (cx || cy)
        ccc++;    // Handle carry from before last operation
    if (ccc)
    {
        i = n2 - 1;
        alu.add(z[i], z[i], ccc);
        for (i--; i>=0; i--)
            if (alu.cy)
                alu.inc(z[i]);
            else
                break;
    }

    delete[] q;
    #undef _add
    #undef _sub
    }

//---------------------------------------------------------------------------
void arbnum::mul_karatsuba(const arbnum &x, const arbnum &y)
{
    // O(3*(N)^log2(3)) ~ O(3*(N^1.585))
    // Karatsuba multiplication
    //
    int s = x.sig*y.sig;
    arbnum a, b;
    a = x;
    b = y;
    a.sig = +1;
    b.sig = +1;
    int i, n;
    for (n = 1; (n < a.siz) || (n < b.siz); n <<= 1)
        ;
    a._realloc(n);
    b._realloc(n);
    _alloc(n + n);
    for (i=0; i < siz; i++)
        dat[i]=0;
    _mul_karatsuba(dat, a.dat, b.dat, n);
    bits = siz << 5;
    sig = s;
    exp = a.exp + b.exp + ((siz-a.siz-b.siz)<<5) + 1;
    //    _normalize();
    }
//---------------------------------------------------------------------------

Minhasarbnum representação numérica:

// dat is MSDW first ... LSDW last
DWORD *dat; int siz,exp,sig,bits;
dat[siz] é a mantisa. LSDW significa DWORD menos significativo.exp é o expoente de MSB dedat[0]

O primeiro bit diferente de zero está presente na mantissa !!!

// |-----|---------------------------|---------------|------|
// | sig | MSB      mantisa      LSB |   exponent    | bits |
// |-----|---------------------------|---------------|------|
// | +1  | 0.(0      ...          0) | 2^0           |   0  | +zero
// | -1  | 0.(0      ...          0) | 2^0           |   0  | -zero
// |-----|---------------------------|---------------|------|
// | +1  | 1.(dat[0] ... dat[siz-1]) | 2^exp         |   n  | +number
// | -1  | 1.(dat[0] ... dat[siz-1]) | 2^exp         |   n  | -number
// |-----|---------------------------|---------------|------|
// | +1  | 1.0                       | 2^+0x7FFFFFFE |   1  | +infinity
// | -1  | 1.0                       | 2^+0x7FFFFFFE |   1  | -infinity
// |-----|---------------------------|---------------|------|

questionAnswers(2)

yourAnswerToTheQuestion