python precision + fft

时间:2015-07-19 10:06:23

标签: python fft precision

我一直在学习fft,尤其是在竞争性编码竞赛中使用。一些谷歌搜索后我想出了一个python实现。为了测试它,我也为多项式乘法创建了一个天真的实现。

代码如下:

from cmath import pi, cos, sin
mod = 100003

def fft(x, inv):
    n = len(x)
    j = 0
    for i in xrange(1, n):
        b = n >> 1
        while j >= b:
            j -= b
            b >>= 1
        j += b
        if j > i:
            x[i], x[j] = x[j], x[i]
    trans_size = 2
    while trans_size <= n:
        w = 1
        ang = [-2 * pi / trans_size, 2 * pi / trans_size][inv]
        wn = complex(cos(ang), sin(ang))
        for t in xrange(trans_size >> 1):
            for trans in xrange(n / trans_size):
                i = trans * trans_size + t
                j = i + (trans_size >> 1)
                a = x[i]
                b = x[j] * w
                x[i] = a + b
                x[j] = a - b
            w *= wn
        trans_size <<= 1

def polymul(x, y):
    lx = len(x)
    ly = len(y)
    lz = 1 << (lx + ly - 1).bit_length()
    x += [0] * (lz - lx)
    y += [0] * (lz - ly)
    fft(x, 0)
    fft(y, 0)
    z = []
    for i in xrange(lz):
        z.append(x[i] * y[i])
    fft(z, 1)
    for i in xrange(lx + ly - 1):
        z[i] = int(z[i].real/lz + 0.5) % mod
    return z[:lx + ly - 1]

def naive_polymul(x, y):
    lx = len(x)
    ly = len(y)
    z = [0 for i in xrange(lx + ly - 1)]
    for i in xrange(lx):
        for j in xrange(ly):
            z[i + j] += x[i] * y[j]
    for i in xrange(lx + ly - 1):
        z[i] %= mod
    return z

from random import randint

N = input()
A = [randint(1, 100000) for i in xrange(N)]
B = [randint(1, 100000) for i in xrange(N)]
C = naive_polymul(A, B)
D = polymul(A, B)
assert(C == D)

现在,我给出随机生成的两个多项式的大小作为输入。直到4000它工作正常,但在5000断言失败。

>>> ================================ RESTART ================================
>>> 
4000
>>> ================================ RESTART ================================
>>> 
5000

Traceback (most recent call last):
  File "fft.py", line 75, in <module>
    assert(C == D)
AssertionError

我稍微使用了代码,我很确定问题在于浮点值引起的精度误差,系数值总是偏离1.提高精度的唯一方法我知道是十进制但是它显着减慢了代码的速度,这与我在这里完全相反。

因此,我在(https://www.hackerrank.com/challenges/emma-and-sum-of-products)上测试我的代码的问题在较大的测试用例中给出了错误的答案,而C ++提交(来自其他用户)使用标准的long double传递它们很好。< / p>

我在这里缺少什么?如果有人能指出我正确的方向来解决这个问题,我将不胜感激。

谢谢!

0 个答案:

没有答案