NTT w / Montgomery Multiplication

时间:2014-06-18 01:13:27

标签: c++ algorithm multiplication modular-arithmetic

在过去的几天里,我一直在努力帮助Spektre先生,由于兼容性问题,他必须为FFT乘法编写自己的数字理论变换。
Modular arithmetics and NTT (finite field DFT) optimizations

他有一个效果很好,但他一直想知道是否有任何方法可以加快速度。想到的一个想法是使用蒙哥马利乘法以避免过度分裂。我过去曾经使用它,但由于某种原因我不能在这里工作,而且我不确定这是蒙哥马利乘法或NTT的问题

它使用32位字大小,因此减少也是2 ^ 32,素数模数是3221225473。使用Ext。欧几里德算法,我发现反转是:
2 ^ 32 * 2415919104 =(3221225473 * 3221225471)+ 1

下面是我正在处理的代码,主要功能是调用它 注意:此时我并不担心逆变换,因为如果常规变换根本不起作用,那就没有意义了。

#include <string.h>


#ifndef uint32
#define uint32 unsigned long int
#endif

#ifndef uint64
#define uint64 unsigned long long int
#endif


class montgom_ntt                                   // number theoretic transform
{
public:
    montgom_ntt()
    {
        r = 0; L = 0; 
        W = 0, N = 0;
    }
    // main interface
    void  NTT(uint32 *dst, uint32 *src, uint32 n = 0);             // uint32 dst[n] = fast  NTT(uint32 src[n])

private:
    bool init(uint32 n);                                     // init r,L,p,W,iW,rN
    void NTT_calc(uint32 *dst, uint32 *src, uint32 n, uint32 w);  // uint32 dst[n] = fast  NTT(uint32 src[n])

    void  NTT_fast(uint32 *dst, uint32 *src, uint32 n, uint32 w);  // uint32 dst[n] = fast  NTT(uint32 src[n])
    void NTT_fast(uint32 *dst, const uint32 *src, uint32 n, uint32 w);

    // uint32 arithmetics
    public:
    uint32 montgom_in(uint32 n);
    uint32 montgom_out(uint32 n);
    void montgom_in_arr(uint32* dst, const uint32* src, uint32 n);
    void montgom_out_arr(uint32* dst, const uint32* src, uint32 n);

    private:

    // modular arithmetics
    inline uint32 modadd(uint32 a, uint32 b);
    inline uint32 modsub(uint32 a, uint32 b);
    inline uint32 modmul(uint32 a, uint32 b);
    inline uint32 modpow(uint32 a, uint32 b);

    uint32 r, L, N, W;

    const uint32 p = 0xC0000001;
    const uint64 px = 0xC0000001;
};

//---------------------------------------------------------------------------
bool montgom_ntt::init(uint32 n)
{
    // (max(src[])^2)*n < p else NTT overflow can ocur !!!
    r = 2;

    if ((n < 2) || (n > 0x10000000))
    {
        r = 0; L = 0; W = 0; // p = 0;
        iW = 0; rN = 0; N = 0;
        return false;
    }
    L = 0x30000000 / n; // 32:30 bit best for unsigned 32 bit

    N = n;               // size of vectors [uint32s]
    W = modpow(r, L); // Wn for NTT
    W = montgom_in(W);

    return true;
}

//---------------------------------------------------------------------------
void montgom_ntt::NTT(uint32 *dst, uint32 *src, uint32 n)
{
    if (n > 0)
    {
        init(n);
    }
    NTT_fast(dst, src, N, W);
}


//---------------------------------------------------------------------------
void montgom_ntt::NTT_fast(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
    if (n > 1)
    {
        if (dst != src)
        {
            NTT_calc(dst, src, n, w);
        }
        else
        {
            uint32* temp = new uint32[n];
            memcpy(temp, src, sizeof(uint32) * n);
            NTT_calc(dst, temp, n, w);

            delete[] temp;
        }
    }
    else if (n == 1)
    {
        dst[0] = src[0];
    }
}


void montgom_ntt::NTT_calc(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
    if (n > 1)
    {
        uint32 i, j, a0, a1,
            n2 = n >> 1,
            w2 = modmul(w, w);

        // reorder even,odd
        for (i = 0, j = 0; i < n2; i++, j += 2)
        {
            dst[i] = src[j];
        }
        for (j = 1; i < n; i++, j += 2)
        {
            dst[i] = src[j];
        }
        // recursion
        if (n2 > 1)
        {
            NTT_calc(src, dst, n2, w2);  // even
            NTT_calc(src + n2, dst + n2, n2, w2);  // odd
        }
        else if (n2 == 1)
        {
            src[0] = dst[0];
            src[1] = dst[1];
        }

        // restore results

        w2 = 1, i = 0, j = n2;
        a0 = src[i];
        a1 = src[j];

        dst[i] = modadd(a0, a1);
        dst[j] = modsub(a0, a1);
        while (++i < n2)
        {
            w2 = modmul(w2, w);
            j++;
            a0 = src[i];
            a1 = modmul(src[j], w2);
            dst[i] = modadd(a0, a1);
            dst[j] = modsub(a0, a1);
        }
    }
}

//---------------------------------------------------------------------------
void montgom_ntt::NTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
    uint32 i, j, wj, wi, a,
        n2 = n >> 1;
    for (wj = 1, j = 0; j < n; j++)
    {
        a = 0;
        for (wi = 1, i = 0; i < n; i++)
        {
            a = modadd(a, modmul(wi, src[i]));
            wi = modmul(wi, wj);
        }
        dst[j] = a;
        wj = modmul(wj, w);
    }
}


//---------------------------------------------------------------------------
uint32 montgom_ntt::montgom_in(uint32 n)
{
    uint64 N = n;
    N = (N << 32) % px;
    return N;
}

//---------------------------------------------------------------------------
uint32 montgom_ntt::montgom_out(uint32 n)
{
    const uint64 C = 0x90000000;
    uint64 N = n;
    N *= C;
    N %= px;

    return N;
}

//---------------------------------------------------------------------------
void montgom_ntt::montgom_in_arr(uint32* dst, const uint32* src, uint32 n)
{
    uint32 I = 0;

    do
    {
        dst[I] = montgom_in(src[I]);
    } while (++I < n);
}

//---------------------------------------------------------------------------
void montgom_ntt::montgom_out_arr(uint32* dst, const uint32* src, uint32 n)
{
    uint32 I = 0;

    do
    {
        dst[I] = montgom_out(src[I]);
    } while (++I < n);
}

//---------------------------------------------------------------------------
uint32 montgom_ntt::modadd(uint32 a, uint32 b)
{
    uint32 n = a + b;
    if (n < a)
    {
        n -= p;
    }
    else if (n >= p)
    {
        n -= p;
    }
    return n;
}

//---------------------------------------------------------------------------
uint32 montgom_ntt::modsub(uint32 a, uint32 b)
{
    uint32 d;

    d = a - b;
    if(a < b)
    {
        d += p;
        d = (a + p) - b;
    }
    return d;
}

//---------------------------------------------------------------------------
uint32 montgom_ntt::modmul(uint32 a, uint32 b)
{
    uint64 A(a), B(b), C;
    uint32 R;

    A *= B;
    C = A & 0xFFFFFFFF;
    C *= 0xBFFFFFFF;

    C = (C & 0xFFFFFFFF) * px;
    C += A;
    R = (C >> 32);
    if(C < A)
    {
        R -= p;
    }
    if(R >= p) 
    {
        R -= p;
    }   
    return R;

}

//---------------------------------------------------------------------------
uint32 montgom_ntt::modpow(uint32 a, uint32 b)
{
    //*
    uint64 D, M, A;

    P = p; A = a;
    M = 0llu - (b & 1);
    D = (M & A) | ((~M) & 1);

    while ((b >>= 1) != 0)
    {
        A = (A * A) % P;

        if ((b & 1) == 1)
        {
            D = (D * A) % P;
        }
    }
    return (uint32)D;
}

和这里的主要

void main()
{
    montgom_ntt F;

    uint32 Tran[8];
    uint32 Arr[8] = 
    {
        0x2923, 0xbe84,
        0xe16c, 0xd6ae,
        0, 0, 0, 0
    };

    F.montgom_in_arr(Arr1, Arr1, Len);
    F.NTT(Tran, Arr, Len);
    F.montgom_out_arr(Tran, Tran, Len);

}

我感觉它很简单,但我无法弄清楚它是什么。 感谢您提供的任何帮助!

[编辑] 因此,为了排除它,我修改了modmul函数,以便将其输入从蒙哥马利形式转换为常规形式,执行标准(A * B)%p,然后将其转换回蒙哥马利形式,我仍然得到同样,错误的答案。这让我觉得问题是蒙哥马利形式的转换,但我不知道我做错了什么。

uint32 montgom_ntt::modmul(uint32 a, uint32 b)
{
    uint64 A, B, C;

    A = montgom_out(a);
    B = montgom_out(b);
    C = (A * B) % px;
    return montgom_in(C);        

    /*
    uint64 A(a), B(b), C;
    uint32 R;

    A *= B;
    C = A & 0xFFFFFFFF;
    C *= 0xBFFFFFFF;

    C = (C & 0xFFFFFFFF) * px;
    C += A;
    R = (C >> 32);
    if(C < A)
    {
        R -= p;
    }
    if(R >= p) 
    {
        R -= p;
    }   
    return R;
    */
}

0 个答案:

没有答案