我一直在学习fft,尤其是在竞争性编码竞赛中使用。一些谷歌搜索后我想出了一个python实现。为了测试它,我也为多项式乘法创建了一个天真的实现。
代码如下:
from cmath import pi, cos, sin
mod = 100003
def fft(x, inv):
n = len(x)
j = 0
for i in xrange(1, n):
b = n >> 1
while j >= b:
j -= b
b >>= 1
j += b
if j > i:
x[i], x[j] = x[j], x[i]
trans_size = 2
while trans_size <= n:
w = 1
ang = [-2 * pi / trans_size, 2 * pi / trans_size][inv]
wn = complex(cos(ang), sin(ang))
for t in xrange(trans_size >> 1):
for trans in xrange(n / trans_size):
i = trans * trans_size + t
j = i + (trans_size >> 1)
a = x[i]
b = x[j] * w
x[i] = a + b
x[j] = a - b
w *= wn
trans_size <<= 1
def polymul(x, y):
lx = len(x)
ly = len(y)
lz = 1 << (lx + ly - 1).bit_length()
x += [0] * (lz - lx)
y += [0] * (lz - ly)
fft(x, 0)
fft(y, 0)
z = []
for i in xrange(lz):
z.append(x[i] * y[i])
fft(z, 1)
for i in xrange(lx + ly - 1):
z[i] = int(z[i].real/lz + 0.5) % mod
return z[:lx + ly - 1]
def naive_polymul(x, y):
lx = len(x)
ly = len(y)
z = [0 for i in xrange(lx + ly - 1)]
for i in xrange(lx):
for j in xrange(ly):
z[i + j] += x[i] * y[j]
for i in xrange(lx + ly - 1):
z[i] %= mod
return z
from random import randint
N = input()
A = [randint(1, 100000) for i in xrange(N)]
B = [randint(1, 100000) for i in xrange(N)]
C = naive_polymul(A, B)
D = polymul(A, B)
assert(C == D)
现在,我给出随机生成的两个多项式的大小作为输入。直到4000它工作正常,但在5000断言失败。
>>> ================================ RESTART ================================
>>>
4000
>>> ================================ RESTART ================================
>>>
5000
Traceback (most recent call last):
File "fft.py", line 75, in <module>
assert(C == D)
AssertionError
我稍微使用了代码,我很确定问题在于浮点值引起的精度误差,系数值总是偏离1.提高精度的唯一方法我知道是十进制但是它显着减慢了代码的速度,这与我在这里完全相反。
因此,我在(https://www.hackerrank.com/challenges/emma-and-sum-of-products)上测试我的代码的问题在较大的测试用例中给出了错误的答案,而C ++提交(来自其他用户)使用标准的long double传递它们很好。< / p>
我在这里缺少什么?如果有人能指出我正确的方向来解决这个问题,我将不胜感激。
谢谢!