Быстрое вычисление квадрата

Чтобы ускорить мои бигнум деления мне нужно ускорить операциюy = x^2 для bigints, которые представлены в виде динамических массивов беззнаковых DWORD. Чтобы было ясно:

DWORD x[n+1] = { LSW, ......, MSW };
где n + 1 - количество используемых DWORDтак значение числаx = x[0]+x[1]<<32 + ... x[N]<<32*(n)

Вопрос в том:Как мне вычислитьy = x^2 максимально быстро без потери точности? - С помощьюC ++ и с целочисленной арифметикой (32 бита с Carry) в распоряжении.

Мой текущий подход заключается в применении умноженияy = x*x и избежать многократного умножения.

Например:

x = x[0] + x[1]<<32 + ... x[n]<<32*(n)

Для простоты позвольте мне переписать это:

x = x0+ x1 + x2 + ... + xn

где индекс представляет адрес внутри массива, поэтому:

y = x*x
y = (x0 + x1 + x2 + ...xn)*(x0 + x1 + x2 + ...xn)
y = x0*(x0 + x1 + x2 + ...xn) + x1*(x0 + x1 + x2 + ...xn) + x2*(x0 + x1 + x2 + ...xn) + ...xn*(x0 + x1 + x2 + ...xn)

y0     = x0*x0
y1     = x1*x0 + x0*x1
y2     = x2*x0 + x1*x1 + x0*x2
y3     = x3*x0 + x2*x1 + x1*x2
...
y(2n-3) = xn(n-2)*x(n  ) + x(n-1)*x(n-1) + x(n  )*x(n-2)
y(2n-2) = xn(n-1)*x(n  ) + x(n  )*x(n-1)
y(2n-1) = xn(n  )*x(n  )

При ближайшем рассмотрении становится ясно, что почти всеxi*xj появляется дважды (не первый и не последний), что означает, чтоN*N умножения могут быть заменены(N+1)*(N/2) умножения. Постскриптум32bit*32bit = 64bit так что результат каждогоmul+add операция обрабатывается как64+1 bit.

Есть ли лучший способ вычислить это быстро? Все, что я нашел во время поисков, было алгоритмами sqrts, а не sqr ...

Fast sqr

!!! Помните, что все числа в моем коде - сначала MSW, а не как в предыдущем тесте (сначала LSW для простоты уравнений, иначе это был бы беспорядок в индексе).

Текущая функциональная реализация fsqr

void arbnum::sqr(const arbnum &x)
{
    // O((N+1)*N/2)
    arbnum c;
    DWORD h, l;
    int N, nx, nc, i, i0, i1, k;
    c._alloc(x.siz + x.siz + 1);
    nx = x.siz - 1;
    nc = c.siz - 1;
    N = nx + nx;
    for (i=0; i<=nc; i++)
        c.dat[i]=0;
    for (i=1; i<N; i++)
        for (i0=0; (i0<=nx) && (i0<=i); i0++)
        {
            i1 = i - i0;
            if (i0 >= i1)
                break;
            if (i1 > nx)
                continue;
            h = x.dat[nx-i0];
            if (!h)
                continue;
            l = x.dat[nx-i1];
            if (!l)
                continue;
            alu.mul(h, l, h, l);
            k = nc - i;
            if (k >= 0)
                alu.add(c.dat[k], c.dat[k], l);
            k--;
            if (k>=0)
                alu.adc(c.dat[k], c.dat[k],h);
            k--;
            for (; (alu.cy) && (k>=0); k--)
                alu.inc(c.dat[k]);
        }
        c.shl(1);
        for (i = 0; i <= N; i += 2)
        {
            i0 = i>>1;
            h = x.dat[nx-i0];
            if (!h)
                continue;
            alu.mul(h, l, h, h);
            k = nc - i;
            if (k >= 0)
                alu.add(c.dat[k], c.dat[k],l);
            k--;
            if (k>=0)
                alu.adc(c.dat[k], c.dat[k], h);
            k--;
            for (; (alu.cy) && (k >= 0); k--)
                alu.inc(c.dat[k]);
        }
        c.bits = c.siz<<5;
        c.exp = x.exp + x.exp + ((c.siz - x.siz - x.siz)<<5) + 1;
        c.sig = sig;
        *this = c;
    }

Использование умножения Карацубы

(спасибо Калпису)

Я реализовал умножение Карацубы, но результаты значительно медленнее, чем при использовании простыхO(N^2) умножение, вероятно, из-за той ужасной рекурсии, которую я не вижу способа избежать. Это компромисс должен быть в действительно больших количествах (больше чем сотни цифр) ... но даже тогда есть много передач памяти. Есть ли способ избежать рекурсивных вызовов (нерекурсивный вариант, ... Почти все рекурсивные алгоритмы могут быть выполнены таким образом). Тем не менее, я постараюсь изменить ситуацию и посмотреть, что произойдет (избегайте нормализаций и т. Д., Также это может быть глупой ошибкой в коде). Во всяком случае, после решения Карацуба на случайx*x прирост производительности невелик.

Оптимизировано умножение Карацубы

Тест производительности дляy = x^2 looped 1000x times, 0.9 < x < 1 ~ 32*98 bits:

x = 0.98765588997654321000000009876... | 98*32 bits
sqr [ 213.989 ms ] ... O((N+1)*N/2) fast sqr
mul1[ 363.472 ms ] ... O(N^2) classic multiplication
mul2[ 349.384 ms ] ... O(3*(N^log2(3))) optimized Karatsuba multiplication
mul3[ 9345.127 ms] ... O(3*(N^log2(3))) unoptimized Karatsuba multiplication

x = 0.98765588997654321000... | 195*32 bits
sqr [ 883.01 ms ]
mul1[ 1427.02 ms ]
mul2[ 1089.84 ms ]

x = 0.98765588997654321000... | 389*32 bits
sqr [ 3189.19 ms ]
mul1[ 5553.23 ms ]
mul2[ 3159.07 ms ]

После оптимизации для Karatsuba код стал значительно быстрее, чем раньше. Тем не менее, для меньших чисел это чуть меньше половины скорости моегоO(N^2) умножение. Для больших чисел это быстрее с коэффициентом, определяемым сложностями умножения Бута. Пороговое значение для умножения составляет около 32 * 98 битов, а для sqr - около 32 * 389 битов, поэтому, если сумма входных битов пересекает этот порог, то умножение Карацубы будет использовано для ускорения умножения, что также будет схожим для sqr.

Кстати, оптимизации включены:

Свести к минимуму кучи с помощью слишком большого аргумента рекурсииВместо этого используется предотвращение любых 32-битных ALU с арифметикой bignum (+, -).игнорирование0*y или жеx*0 или же0*0 случаиПереформатирование вводаx,y количество чисел к степени два, чтобы избежать перераспределенияРеализация умножения по модулю дляz1 = (x0 + x1)*(y0 + y1) минимизировать рекурсию

Модифицированное умножение Шёнхаге-Штрассена для реализации sqr

Я проверил использованиеFFT а такжеNTT преобразовывает, чтобы ускорить вычисление sqr. Результаты таковы:

FFT

Потерять точность и, следовательно, нужны высокоточные комплексные числа Это на самом деле значительно замедляет процесс, поэтому ускорение отсутствует. Результат не является точным (может быть ошибочно округлен), поэтомуFFT непригоден (пока)

NTT

NTT конечное полеДПФ и поэтому не происходит потеря точности. Нужна модульная арифметика на целых числах без знака:modpow, modmul, modadd а такжеmodsub.

я используюDWORD (32-разрядные целые числа без знака).NTT Размер вектора ввода / вывода ограничен из-за проблем переполнения !!! Для 32-битной модульной арифметикиN ограничено(2^32)/(max(input[])^2) такbigint должен быть разделен на более мелкие куски (я используюBYTES поэтому максимальный размерbigint обработано

(2^32)/((2^8)^2) = 2^16 bytes = 2^14 DWORDs = 16384 DWORDs)

sqr использует только1xNTT + 1xINTT вместо2xNTT + 1xINTT для умножения, ноNTT использование слишком медленное и размер порогового числа слишком велик для практического использования в моей реализации (дляmul а также дляsqr).

Возможно даже превышение предела переполнения, поэтому следует использовать 64-битную модульную арифметику, которая может еще больше замедлить работу. ТакNTT для моих целей тоже непригодна.

Некоторые измерения:

a = 0.98765588997654321000 | 389*32 bits
looped 1x times
sqr1[ 3.177 ms ] fast sqr
sqr2[ 720.419 ms ] NTT sqr
mul1[ 5.588 ms ] simpe mul
mul2[ 3.172 ms ] karatsuba mul
mul3[ 1053.382 ms ] NTT mul

Моя реализация:

void arbnum::sqr_NTT(const arbnum &x)
{
    // O(N*log(N)*(log(log(N)))) - 1x NTT
    // Schönhage-Strassen sqr
    // To prevent NTT overflow: n <= 48K * 8 bit -> result siz <= 12K * 32 bit -> x.siz + y.siz <= 12K!!!
    int i, j, k, n;
    int s = x.sig*x.sig, exp0 = x.exp + x.exp - ((x.siz+x.siz)<<5) + 2;
    i = x.siz;
    for (n = 1; n < i; n<<=1)
        ;
    if (n + n > 0x3000) {
        _error(_arbnum_error_TooBigNumber);
        zero();
        return;
    }
    n <<= 3;
    DWORD *xx, *yy, q, qq;
    xx = new DWORD[n+n];
    #ifdef _mmap_h
    if (xx)
        mmap_new(xx, (n+n) << 2);
    #endif
    if (xx==NULL) {
        _error(_arbnum_error_NotEnoughMemory);
        zero();
        return;
    }
    yy = xx + n;

    // Zero padding (and split DWORDs to BYTEs)
    for (i--, k=0; i >= 0; i--)
    {
        q = x.dat[i];
        xx[k] = q&0xFF; k++; q>>=8;
        xx[k] = q&0xFF; k++; q>>=8;
        xx[k] = q&0xFF; k++; q>>=8;
        xx[k] = q&0xFF; k++;
    }
    for (;k<n;k++)
        xx[k] = 0;

    //NTT
    fourier_NTT ntt;

    ntt.NTT(yy,xx,n);    // init NTT for n

    // Convolution
    for (i=0; i<n; i++)
        yy[i] = modmul(yy[i], yy[i], ntt.p);

    //INTT
    ntt.INTT(xx, yy);

    //suma
    q=0;
    for (i = 0, j = 0; i<n; i++) {
        qq = xx[i];
        q += qq&0xFF;
        yy[n-i-1] = q&0xFF;
        q>>=8;
        qq>>=8;
        q+=qq;
    }

    // Merge WORDs to DWORDs and copy them to result
    _alloc(n>>2);
    for (i = 0, j = 0; i<siz; i++)
    {
        q  =(yy[j]<<24)&0xFF000000; j++;
        q |=(yy[j]<<16)&0x00FF0000; j++;
        q |=(yy[j]<< 8)&0x0000FF00; j++;
        q |=(yy[j]    )&0x000000FF; j++;
        dat[i] = q;
    }

    #ifdef _mmap_h
    if (xx)
        mmap_del(xx);
    #endif
    delete xx;
    bits = siz<<5;
    sig = s;
    exp = exp0 + (siz<<5) - 1;
        // _normalize();
    }

Заключение

Для меньших номеров это лучший вариант мой быстрыйsqr подход, и после порогаКарацуба умножение лучше. Но я все еще думаю, что должно быть что-то тривиальное, что мы упустили из виду. У кого-нибудь есть другие идеи?

NTT оптимизация

После чрезвычайно интенсивных оптимизаций (в основномNTT): Вопрос переполнения стекаМодульная арифметика и NTT (конечно-полевые DFT) оптимизации.

Некоторые значения изменились:

a = 0.98765588997654321000 | 1553*32bits
looped 10x times
mul2[ 28.585 ms ] Karatsuba mul
mul3[ 26.311 ms ] NTT mul

А сейчасNTT умножение, наконец, быстрее, чемКарацуба после 1500 * 32-битного порога.

Некоторые измерения и ошибка обнаружены

a = 0.99991970486 | 1553*32 bits
looped: 10x
sqr1[  58.656 ms ] fast sqr
sqr2[  13.447 ms ] NTT sqr
mul1[ 102.563 ms ] simpe mul
mul2[  28.916 ms ] Karatsuba mul Error
mul3[  19.470 ms ] NTT mul

Я узнал, что мойКарацуба (больше / меньше) течетLSB каждогоDWORD сегмент бигнум. Когда я исследую, я обновлю код ...

Кроме того, послеNTT оптимизации пороги изменились, поэтому дляNTT sqr это310*32 bits = 9920 bits изоперанд, и дляNTT mul это1396*32 bits = 44672 bits изрезультат (сумма битов операндов).

Код Карацубы исправлен благодаря @greybeard

//---------------------------------------------------------------------------
void arbnum::_mul_karatsuba(DWORD *z, DWORD *x, DWORD *y, int n)
{
    // Recursion for Karatsuba
    // z[2n] = x[n]*y[n];
    // n=2^m
    int i;
    for (i=0; i<n; i++)
        if (x[i]) {
            i=-1;
            break;
        } // x==0 ?

    if (i < 0)
        for (i = 0; i<n; i++)
            if (y[i]) {
                i = -1;
                break;
            } // y==0 ?

    if (i >= 0) {
        for (i = 0; i < n + n; i++)
            z[i]=0;
            return;
        } // 0.? = 0

    if (n == 1) {
        alu.mul(z[0], z[1], x[0], y[0]);
        return;
    }

    if (n< 1)
        return;
    int n2 = n>>1;
    _mul_karatsuba(z+n, x+n2, y+n2, n2);                         // z0 = x0.y0
    _mul_karatsuba(z  , x   , y   , n2);                         // z2 = x1.y1
    DWORD *q = new DWORD[n<<1], *q0, *q1, *qq;
    BYTE cx,cy;
    if (q == NULL) {
        _error(_arbnum_error_NotEnoughMemory);
        return;
    }
    #define _add { alu.add(qq[i], q0[i], q1[i]); for (i--; i>=0; i--) alu.adc(qq[i], q0[i], q1[i]); } // qq = q0 + q1 ...[i..0]
    #define _sub { alu.sub(qq[i], q0[i], q1[i]); for (i--; i>=0; i--) alu.sbc(qq[i], q0[i], q1[i]); } // qq = q0 - q1 ...[i..0]
    qq = q;
    q0 = x + n2;
    q1 = x;
    i = n2 - 1;
    _add;
    cx = alu.cy; // =x0+x1

    qq = q + n2;
    q0 = y + n2;
    q1 = y;
    i = n2 - 1;
    _add;
    cy = alu.cy; // =y0+y1

    _mul_karatsuba(q + n, q + n2, q, n2);                       // =(x0+x1)(y0+y1) mod ((2^N)-1)

    if (cx) {
        qq = q + n;
        q0 = qq;
        q1 = q + n2;
        i = n2 - 1;
        _add;
        cx = alu.cy;
    }// += cx*(y0 + y1) << n2

    if (cy) {
        qq = q + n;
        q0 = qq;
        q1 = q;
        i = n2 -1;
        _add;
        cy = alu.cy;
    }// +=cy*(x0+x1)<<n2

    qq = q + n;  q0 = qq; q1 = z + n; i = n - 1; _sub;  // -=z0
    qq = q + n;  q0 = qq; q1 = z;     i = n - 1; _sub;  // -=z2
    qq = z + n2; q0 = qq; q1 = q + n; i = n - 1; _add;  // z1=(x0+x1)(y0+y1)-z0-z2

    DWORD ccc=0;

    if (alu.cy)
        ccc++;    // Handle carry from last operation
    if (cx || cy)
        ccc++;    // Handle carry from before last operation
    if (ccc)
    {
        i = n2 - 1;
        alu.add(z[i], z[i], ccc);
        for (i--; i>=0; i--)
            if (alu.cy)
                alu.inc(z[i]);
            else
                break;
    }

    delete[] q;
    #undef _add
    #undef _sub
    }

//---------------------------------------------------------------------------
void arbnum::mul_karatsuba(const arbnum &x, const arbnum &y)
{
    // O(3*(N)^log2(3)) ~ O(3*(N^1.585))
    // Karatsuba multiplication
    //
    int s = x.sig*y.sig;
    arbnum a, b;
    a = x;
    b = y;
    a.sig = +1;
    b.sig = +1;
    int i, n;
    for (n = 1; (n < a.siz) || (n < b.siz); n <<= 1)
        ;
    a._realloc(n);
    b._realloc(n);
    _alloc(n + n);
    for (i=0; i < siz; i++)
        dat[i]=0;
    _mul_karatsuba(dat, a.dat, b.dat, n);
    bits = siz << 5;
    sig = s;
    exp = a.exp + b.exp + ((siz-a.siz-b.siz)<<5) + 1;
    //    _normalize();
    }
//---------------------------------------------------------------------------

мойarbnum представление чисел:

// dat is MSDW first ... LSDW last
DWORD *dat; int siz,exp,sig,bits;
dat[siz] это богомол LSDW означает наименее значимый DWORD.exp является показателем MSBdat[0]

Первый ненулевой бит присутствует в мантиссе !!!

// |-----|---------------------------|---------------|------|
// | sig | MSB      mantisa      LSB |   exponent    | bits |
// |-----|---------------------------|---------------|------|
// | +1  | 0.(0      ...          0) | 2^0           |   0  | +zero
// | -1  | 0.(0      ...          0) | 2^0           |   0  | -zero
// |-----|---------------------------|---------------|------|
// | +1  | 1.(dat[0] ... dat[siz-1]) | 2^exp         |   n  | +number
// | -1  | 1.(dat[0] ... dat[siz-1]) | 2^exp         |   n  | -number
// |-----|---------------------------|---------------|------|
// | +1  | 1.0                       | 2^+0x7FFFFFFE |   1  | +infinity
// | -1  | 1.0                       | 2^+0x7FFFFFFE |   1  | -infinity
// |-----|---------------------------|---------------|------|

Ответы на вопрос(2)

Ваш ответ на вопрос