在有限域上实现FFT

时间:2018-09-11 06:57:12

标签: math fft dft finite-field ntt

我想使用NTT实现多项式的乘法。我遵循了Number-theoretic transform (integer DFT),它似乎可以正常工作。

现在,我想在有限域Z_p[x]上实现多项式的乘法,其中p是任意质数。

与以前的无界情况相比,它现在是否改变了系数p所界定的范围?

特别是,原始NTT需要找到质数N作为大于(magnitude of largest element of input vector)^2 * (length of input vector) + 1的工作模数,以便结果永不溢出。如果结果无论如何都会受到该p的限制,那么模数可以是多少?请注意,p - 1的格式不必为(some positive integer) * (length of input vector)

编辑:我从上面的链接中复制粘贴了源代码以说明问题:

# 
# Number-theoretic transform library (Python 2, 3)
# 
# Copyright (c) 2017 Project Nayuki
# All rights reserved. Contact Nayuki for licensing.
# https://www.nayuki.io/page/number-theoretic-transform-integer-dft
#

import itertools, numbers

def find_params_and_transform(invec, minmod):
    check_int(minmod)
    mod = find_modulus(len(invec), minmod)
    root = find_primitive_root(len(invec), mod - 1, mod)
    return (transform(invec, root, mod), root, mod)

def check_int(n):
    if not isinstance(n, numbers.Integral):
        raise TypeError()

def find_modulus(veclen, minimum):
    check_int(veclen)
    check_int(minimum)
    if veclen < 1 or minimum < 1:
        raise ValueError()
    start = (minimum - 1 + veclen - 1) // veclen
    for i in itertools.count(max(start, 1)):
        n = i * veclen + 1
        assert n >= minimum
        if is_prime(n):
            return n

def is_prime(n):
    check_int(n)
    if n <= 1:
        raise ValueError()
    return all((n % i != 0) for i in range(2, sqrt(n) + 1))

def sqrt(n):
    check_int(n)
    if n < 0:
        raise ValueError()
    i = 1
    while i * i <= n:
        i *= 2
    result = 0
    while i > 0:
        if (result + i)**2 <= n:
            result += i
        i //= 2
    return result

def find_primitive_root(degree, totient, mod):
    check_int(degree)
    check_int(totient)
    check_int(mod)
    if not (1 <= degree <= totient < mod):
        raise ValueError()
    if totient % degree != 0:
        raise ValueError()
    gen = find_generator(totient, mod)
    root = pow(gen, totient // degree, mod)
    assert 0 <= root < mod
    return root

def find_generator(totient, mod):
    check_int(totient)
    check_int(mod)
    if not (1 <= totient < mod):
        raise ValueError()
    for i in range(1, mod):
        if is_generator(i, totient, mod):
            return i
    raise ValueError("No generator exists")

def is_generator(val, totient, mod):
    check_int(val)
    check_int(totient)
    check_int(mod)
    if not (0 <= val < mod):
        raise ValueError()
    if not (1 <= totient < mod):
        raise ValueError()
    pf = unique_prime_factors(totient)
    return pow(val, totient, mod) == 1 and all((pow(val, totient // p, mod) != 1) for p in pf)

def unique_prime_factors(n):
    check_int(n)
    if n < 1:
        raise ValueError()
    result = []
    i = 2
    end = sqrt(n)
    while i <= end:
        if n % i == 0:
            n //= i
            result.append(i)
            while n % i == 0:
                n //= i
            end = sqrt(n)
        i += 1
    if n > 1:
        result.append(n)
    return result

def transform(invec, root, mod):
    check_int(root)
    check_int(mod)
    if len(invec) >= mod:
        raise ValueError()
    if not all((0 <= val < mod) for val in invec):
        raise ValueError()
    if not (1 <= root < mod):
        raise ValueError()

    outvec = []
    for i in range(len(invec)):
        temp = 0
        for (j, val) in enumerate(invec):
            temp += val * pow(root, i * j, mod)
            temp %= mod
        outvec.append(temp)
    return outvec

def inverse_transform(invec, root, mod):
    outvec = transform(invec, reciprocal(root, mod), mod)
    scaler = reciprocal(len(invec), mod)
    return [(val * scaler % mod) for val in outvec]

def reciprocal(n, mod):
    check_int(n)
    check_int(mod)
    if not (0 <= n < mod):
        raise ValueError()
    x, y = mod, n
    a, b = 0, 1
    while y != 0:
        a, b = b, a - x // y * b
        x, y = y, x % y
    if x == 1:
        return a % mod
    else:
        raise ValueError("Reciprocal does not exist")

def circular_convolve(vec0, vec1):
    if not (0 < len(vec0) == len(vec1)):
        raise ValueError()
    if any((val < 0) for val in itertools.chain(vec0, vec1)):
        raise ValueError()
    maxval = max(val for val in itertools.chain(vec0, vec1))
    minmod = maxval**2 * len(vec0) + 1
    temp0, root, mod = find_params_and_transform(vec0, minmod)
    temp1 = transform(vec1, root, mod)
    temp2 = [(x * y % mod) for (x, y) in zip(temp0, temp1)]
    return inverse_transform(temp2, root, mod)

vec0 = [24, 12, 28, 8, 0, 0, 0, 0]
vec1 = [4, 26, 29, 23, 0, 0, 0, 0]

print(circular_convolve(vec0, vec1))

def modulo(vec, prime):
    return [x % prime for x in vec]

print(modulo(circular_convolve(vec0, vec1), 31))

打印:

[96, 672, 1120, 1660, 1296, 876, 184, 0]
[3, 21, 4, 17, 25, 8, 29, 0]

但是,当我将minmod = maxval**2 * len(vec0) + 1更改为minmod = maxval + 1时,它将停止工作:

[14, 16, 13, 20, 25, 15, 20, 0]
[14, 16, 13, 20, 25, 15, 20, 0]

为了达到预期效果,最小的minmod(在上面的链接中的N)是什么?

1 个答案:

答案 0 :(得分:1)

如果您输入的n整数绑定到某个质数q(任何mod q不仅是质数都将是相同的),您可以将其用作max value +1,但请注意,您不能将其用作 NTT 的素数p,因为 NTT 素数p具有特殊的属性。他们都在这里:

因此,每个输入的最大值为q-1,但在您的任务计算过程中(对2个 NTT 结果进行卷积),第一层结果的大小可以上升到n.(q-1)但是当我们对它们进行卷积时,最终 iNTT 的输入幅度将上升为:

m = n.((q-1)^2)

如果您在 NTT 上执行的操作不同于m等式,则可能会发生变化。

现在让我们回到p,因此简而言之,您可以使用支持这些的任何素数p

p mod n == 1
p > m

并且存在1 <= r,L < p这样:

p mod (L-1) = 0
r^(L*i) mod p == 1 // i = { 0,n }
r^(L*i) mod p != 1 // i = { 1,2,3, ... n-1 }

如果这一切都得到满足,那么p将是第n个统一根,并可用于 NTT 。要找到这样的素数和r,L,请查看上面的链接(有C ++代码可以找到这样的素数。)

例如,在字符串乘法期间,我们对它们进行2个字符串 NTT ,然后对结果进行卷积,并 iNTT 返回结果(这是两个输入大小的总和)。例如:

                                99999999999999999999999999999999
                               *99999999999999999999999999999999
----------------------------------------------------------------
9999999999999999999999999999999800000000000000000000000000000001

q = 10和两个操作数均为9 ^ 32,因此n=32因此为m = 9*9*32 = 2592,找到的素数为p = 2689。如您所见,结果匹配,因此不会发生溢出。但是,如果我使用仍然适合所有其他条件的任何较小的质数,则结果将不匹配。我专门用它来尽可能地拉伸NTT值(所有值均为q-1,大小等于2的幂次)

如果您的 NTT 很快并且n不是2的幂,那么您需要将每个 NTT < / strong>。但这不会影响m的值,因为零填充不应该增加值的大小。我的测试证明了这一点,因此您可以使用卷积:

m = (n1+n2).((q-1)^2)/2

其中n1,n2是零填充之前的原始输入大小。

有关实现 NTT 的更多信息,您可以在 C ++ (经过全面优化)中检出我的信息:

因此,请回答您的问题:

  1. 是的,您可以利用以下事实:输入为mod q,但不能将q用作p

  2. 您只能将minmod = n * (maxval + 1)用于单个NTT(或NTT的第一层),但是由于在NTT使用过程中将它们与卷积链接在一起,因此不能在最后的INTT阶段使用它!!!

但是,正如我在评论中提到的那样,最简单的方法是使用最大的p来适应您正在使用的数据类型,并且可用于支持2种输​​入大小的所有幂次

这基本上使您的问题变得无关紧要。我只能想到不可能/不希望出现这种情况的唯一情况是在没有“最大”限制的任意精度数字上。绑定到变量p的性能有很多问题,因为对p的搜索确实很慢(甚至可能比 NTT 本身还要慢),而且变量p禁用了使 NTT 真正慢的所需的模块化算法的许多性能优化。