在过去的几天里,我一直在努力帮助Spektre先生,由于兼容性问题,他必须为FFT乘法编写自己的数字理论变换。
Modular arithmetics and NTT (finite field DFT) optimizations
他有一个效果很好,但他一直想知道是否有任何方法可以加快速度。想到的一个想法是使用蒙哥马利乘法以避免过度分裂。我过去曾经使用它,但由于某种原因我不能在这里工作,而且我不确定这是蒙哥马利乘法或NTT的问题
它使用32位字大小,因此减少也是2 ^ 32,素数模数是3221225473。使用Ext。欧几里德算法,我发现反转是:
2 ^ 32 * 2415919104 =(3221225473 * 3221225471)+ 1
下面是我正在处理的代码,主要功能是调用它 注意:此时我并不担心逆变换,因为如果常规变换根本不起作用,那就没有意义了。
#include <string.h>
#ifndef uint32
#define uint32 unsigned long int
#endif
#ifndef uint64
#define uint64 unsigned long long int
#endif
class montgom_ntt // number theoretic transform
{
public:
montgom_ntt()
{
r = 0; L = 0;
W = 0, N = 0;
}
// main interface
void NTT(uint32 *dst, uint32 *src, uint32 n = 0); // uint32 dst[n] = fast NTT(uint32 src[n])
private:
bool init(uint32 n); // init r,L,p,W,iW,rN
void NTT_calc(uint32 *dst, uint32 *src, uint32 n, uint32 w); // uint32 dst[n] = fast NTT(uint32 src[n])
void NTT_fast(uint32 *dst, uint32 *src, uint32 n, uint32 w); // uint32 dst[n] = fast NTT(uint32 src[n])
void NTT_fast(uint32 *dst, const uint32 *src, uint32 n, uint32 w);
// uint32 arithmetics
public:
uint32 montgom_in(uint32 n);
uint32 montgom_out(uint32 n);
void montgom_in_arr(uint32* dst, const uint32* src, uint32 n);
void montgom_out_arr(uint32* dst, const uint32* src, uint32 n);
private:
// modular arithmetics
inline uint32 modadd(uint32 a, uint32 b);
inline uint32 modsub(uint32 a, uint32 b);
inline uint32 modmul(uint32 a, uint32 b);
inline uint32 modpow(uint32 a, uint32 b);
uint32 r, L, N, W;
const uint32 p = 0xC0000001;
const uint64 px = 0xC0000001;
};
//---------------------------------------------------------------------------
bool montgom_ntt::init(uint32 n)
{
// (max(src[])^2)*n < p else NTT overflow can ocur !!!
r = 2;
if ((n < 2) || (n > 0x10000000))
{
r = 0; L = 0; W = 0; // p = 0;
iW = 0; rN = 0; N = 0;
return false;
}
L = 0x30000000 / n; // 32:30 bit best for unsigned 32 bit
N = n; // size of vectors [uint32s]
W = modpow(r, L); // Wn for NTT
W = montgom_in(W);
return true;
}
//---------------------------------------------------------------------------
void montgom_ntt::NTT(uint32 *dst, uint32 *src, uint32 n)
{
if (n > 0)
{
init(n);
}
NTT_fast(dst, src, N, W);
}
//---------------------------------------------------------------------------
void montgom_ntt::NTT_fast(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
if (n > 1)
{
if (dst != src)
{
NTT_calc(dst, src, n, w);
}
else
{
uint32* temp = new uint32[n];
memcpy(temp, src, sizeof(uint32) * n);
NTT_calc(dst, temp, n, w);
delete[] temp;
}
}
else if (n == 1)
{
dst[0] = src[0];
}
}
void montgom_ntt::NTT_calc(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
if (n > 1)
{
uint32 i, j, a0, a1,
n2 = n >> 1,
w2 = modmul(w, w);
// reorder even,odd
for (i = 0, j = 0; i < n2; i++, j += 2)
{
dst[i] = src[j];
}
for (j = 1; i < n; i++, j += 2)
{
dst[i] = src[j];
}
// recursion
if (n2 > 1)
{
NTT_calc(src, dst, n2, w2); // even
NTT_calc(src + n2, dst + n2, n2, w2); // odd
}
else if (n2 == 1)
{
src[0] = dst[0];
src[1] = dst[1];
}
// restore results
w2 = 1, i = 0, j = n2;
a0 = src[i];
a1 = src[j];
dst[i] = modadd(a0, a1);
dst[j] = modsub(a0, a1);
while (++i < n2)
{
w2 = modmul(w2, w);
j++;
a0 = src[i];
a1 = modmul(src[j], w2);
dst[i] = modadd(a0, a1);
dst[j] = modsub(a0, a1);
}
}
}
//---------------------------------------------------------------------------
void montgom_ntt::NTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
uint32 i, j, wj, wi, a,
n2 = n >> 1;
for (wj = 1, j = 0; j < n; j++)
{
a = 0;
for (wi = 1, i = 0; i < n; i++)
{
a = modadd(a, modmul(wi, src[i]));
wi = modmul(wi, wj);
}
dst[j] = a;
wj = modmul(wj, w);
}
}
//---------------------------------------------------------------------------
uint32 montgom_ntt::montgom_in(uint32 n)
{
uint64 N = n;
N = (N << 32) % px;
return N;
}
//---------------------------------------------------------------------------
uint32 montgom_ntt::montgom_out(uint32 n)
{
const uint64 C = 0x90000000;
uint64 N = n;
N *= C;
N %= px;
return N;
}
//---------------------------------------------------------------------------
void montgom_ntt::montgom_in_arr(uint32* dst, const uint32* src, uint32 n)
{
uint32 I = 0;
do
{
dst[I] = montgom_in(src[I]);
} while (++I < n);
}
//---------------------------------------------------------------------------
void montgom_ntt::montgom_out_arr(uint32* dst, const uint32* src, uint32 n)
{
uint32 I = 0;
do
{
dst[I] = montgom_out(src[I]);
} while (++I < n);
}
//---------------------------------------------------------------------------
uint32 montgom_ntt::modadd(uint32 a, uint32 b)
{
uint32 n = a + b;
if (n < a)
{
n -= p;
}
else if (n >= p)
{
n -= p;
}
return n;
}
//---------------------------------------------------------------------------
uint32 montgom_ntt::modsub(uint32 a, uint32 b)
{
uint32 d;
d = a - b;
if(a < b)
{
d += p;
d = (a + p) - b;
}
return d;
}
//---------------------------------------------------------------------------
uint32 montgom_ntt::modmul(uint32 a, uint32 b)
{
uint64 A(a), B(b), C;
uint32 R;
A *= B;
C = A & 0xFFFFFFFF;
C *= 0xBFFFFFFF;
C = (C & 0xFFFFFFFF) * px;
C += A;
R = (C >> 32);
if(C < A)
{
R -= p;
}
if(R >= p)
{
R -= p;
}
return R;
}
//---------------------------------------------------------------------------
uint32 montgom_ntt::modpow(uint32 a, uint32 b)
{
//*
uint64 D, M, A;
P = p; A = a;
M = 0llu - (b & 1);
D = (M & A) | ((~M) & 1);
while ((b >>= 1) != 0)
{
A = (A * A) % P;
if ((b & 1) == 1)
{
D = (D * A) % P;
}
}
return (uint32)D;
}
和这里的主要
void main()
{
montgom_ntt F;
uint32 Tran[8];
uint32 Arr[8] =
{
0x2923, 0xbe84,
0xe16c, 0xd6ae,
0, 0, 0, 0
};
F.montgom_in_arr(Arr1, Arr1, Len);
F.NTT(Tran, Arr, Len);
F.montgom_out_arr(Tran, Tran, Len);
}
我感觉它很简单,但我无法弄清楚它是什么。 感谢您提供的任何帮助!
[编辑] 因此,为了排除它,我修改了modmul函数,以便将其输入从蒙哥马利形式转换为常规形式,执行标准(A * B)%p,然后将其转换回蒙哥马利形式,我仍然得到同样,错误的答案。这让我觉得问题是蒙哥马利形式的转换,但我不知道我做错了什么。
uint32 montgom_ntt::modmul(uint32 a, uint32 b)
{
uint64 A, B, C;
A = montgom_out(a);
B = montgom_out(b);
C = (A * B) % px;
return montgom_in(C);
/*
uint64 A(a), B(b), C;
uint32 R;
A *= B;
C = A & 0xFFFFFFFF;
C *= 0xBFFFFFFF;
C = (C & 0xFFFFFFFF) * px;
C += A;
R = (C >> 32);
if(C < A)
{
R -= p;
}
if(R >= p)
{
R -= p;
}
return R;
*/
}