在大数乘法上使用FFT的问题

时间:2015-07-12 16:13:32

标签: c++ fft multiplication ifft

我最近试图将一堆(最多10 ^ 5)非常小的数字(10 ^ -4的数量)相互相乘,所以我会以10 ^ - (4 * 10 ^)的顺序结束5)不适合任何变量。

我的方法如下:

  1. 将每个数乘以10 ^ 8并将其存储在以10的幂分割的数组中,即我将其中10的幂给出一个多项式。一个例子是:p = 0.1234 - > p * 10 ^ 8 = 12340000 - > A = {0,0,0,0,4,3,2,1}。
  2. 使用FFT
  3. 乘以这些数组
  4. iFFT结果
  5. 对于少数不同的情况,这是多次完成的。 我最终想知道的是一个这样的产品在所有产品总和上的比例,精度为10 ^ -6。为此,在步骤2和3之间将结果添加到最后也是iFFTed的sum数组中。由于所需的精度非常低,我没有划分多项式,只是将前几个数字转换为整数。

    我的问题如下:FFT和/或iFFT无法正常工作!我是这个东西的新手,并且只实现了一次FFT。我的代码如下(C ++ 14):

    #include <cstdio>
    #include <vector>
    #include <cmath>
    #include <complex>
    #include <iostream>
    using namespace std;
    
    const double PI = 4*atan(1.0);
    
    vector<complex<double>> FFT (vector<complex<double>> A, long N)
    {
        vector<complex<double>> Ans(0);
        if(N==1) 
        {
            Ans.push_back(A[0]);
            return Ans;
        }
        vector<complex<double>> even(N/2);
        vector<complex<double>> odd(N/2);
        for(long i=0; i<(N/2); i++)
        {
            even[i] = A[2*i];
            odd[i] = A[2*i+1];
        }
        vector<complex<double>> L1 = FFT(even, N/2);
        vector<complex<double>> L2 = FFT(odd, N/2);
        for(long i=0; i<N; i++)
        {
            complex<double> z(cos(2*PI*i/N),sin(2*PI*i/N));
            long k = i%(N/2);
            Ans.push_back(L1[k] + z*L2[k]);
        }
        return Ans;
        }
    
        vector<complex<double>> iFFT (vector<complex<double>> A, long N)
        {
        vector<complex<double>> Ans(0);
        if(N==1) 
        {
            Ans.push_back(A[0]);
            return Ans;
        }
        vector<complex<double>> even(N/2);
        vector<complex<double>> odd(N/2);
        for(long i=0; i<(N/2); i++)
        {
            even[i] = A[2*i];
            odd[i] = A[2*i+1];
        }
        vector<complex<double>> L1 = FFT(even, N/2);
        vector<complex<double>> L2 = FFT(odd, N/2);
        for(long i=0; i<N; i++)
        {
            complex<double> z(cos(-2*PI*i/N),sin(-2*PI*i/N));
            complex<double> inv(double(1.0/N), 0);
            long k = i%(N/2);
            Ans.push_back(inv*(L1[k]+z*L2[k]));
        }
        return Ans;
    }
    
    vector<complex<double>> PMult (vector<complex<double>> A, vector<complex<double>> B, long L) 
    {
    vector<complex<double>> Ans(L);
    for(int i=0; i<L; i++)
    {
        Ans[i] = A[i]*B[i];
    }
    return Ans;
    }
    
    vector<complex<double>> DtoA (double x)
    {
        vector<complex<double>> ans(8);
        long X = long(x*10000000);
        ans[0] = complex<double>(double(X%10), 0.0); X/=10;
        ans[1] = complex<double>(double(X%10), 0.0); X/=10;
        ans[2] = complex<double>(double(X%10), 0.0); X/=10;
        ans[3] = complex<double>(double(X%10), 0.0); X/=10;
        ans[4] = complex<double>(double(X%10), 0.0); X/=10;
        ans[5] = complex<double>(double(X%10), 0.0); X/=10;
        ans[6] = complex<double>(double(X%10), 0.0); X/=10;
        ans[7] = complex<double>(double(X%10), 0.0);
        return ans;
    }
    
    int main()
    {
        vector<vector<complex<double>>> W;
        int T, N, M;
        double p;
        scanf("%d", &T);
        while( T-- )
        {
            scanf("%d %d", &N, &M);
            W.resize(N); 
            for(int i=0; i<N; i++)
            {
                cin >> p;
                W[i] = FFT(DtoA(p),8);
                for(int j=1; j<M; j++)
                {
                    cin >> p;
                    W[i] = PMult(W[i], FFT(DtoA(p),8), 8);
                }
            }
            vector<complex<double>> Sum(8);
            for(int j=0; j<8; j++) Sum[j]=W[0][j];
            for(int i=1; i<N; i++)
            {
                for(int j=0; j<8; j++)
                {
                    Sum[j]+=W[i][j];
                }
            }
    
            W[0]=iFFT(W[0],8);
            Sum=iFFT(Sum, 8);
    
            long X=0;
            long Y=0;
            int a;
            for(a=0; Sum[a].real()!=0; a++);
            for(int i=a; i<8; i++)
            {
                Y*=10;
                Y=Y+long(Sum[i].real());
            }
            for(int i=a; i<8; i++)
            {
                X*=10;
                X=X+long(W[0][i].real()); 
            }
            double ans = 0.0;
            if(Y) ans=double(X)/double(Y);
            printf("%.7f\n", ans);
        }
    }
    

    我观察到的是,对于除了一个条目之外仅包含零的数组,FFT返回一个具有多个非空条目的数组。此外,在完成iFFT之后,结果仍然包含具有非零虚部的条目。

    有人可以找到错误或向我提供一个我可以使解决方案更容易的提示吗?因为我希望它快,我不想做一个天真的乘法。 Karatsuba的算法会更好吗,因为我不需要复数?

1 个答案:

答案 0 :(得分:1)

我使用旧的Java实现检查了FFT(抱歉讨厌的代码)。我在FFT和DtoA函数中发现了以下内容:

  • 您的DtoA功能中缺少0
  • 你在向量中设置系数的顺序是相反的(也许这是出于某种原因故意)
  • &#34;结合&#34; Cooley Tukey算法的阶段不正确:术语的前半部分为a + b * c形式,后半部分为a - b * c形式。

以下代码效率不高,但应该清楚。

#include <iostream>
#include <vector>
#include <complex>
#include <exception>
#include <cmath>

using namespace std;

vector<complex<double>> fft(const vector<complex<double>> &x) {
    int N = x.size();

    // base case
    if (N == 1) return vector<complex<double>> { x[0] };

    // radix 2 Cooley-Tukey FFT
    if (N % 2 != 0) { throw new std::exception(); }

    // fft of even terms
    vector<complex<double>> even, odd;
    for (int k = 0; k < N/2; k++) {
        even.push_back(x[2 * k]);
        odd.push_back(x[2 * k + 1]);
    }
    vector<complex<double>> q = fft(even);
    vector<complex<double>> r = fft(odd);

    // combine
    vector<complex<double>> y;
    for (int k = 0; k < N/2; k++) {
        double kth = -2 * k * M_PI / N;
        complex<double> wk = complex<double>(cos(kth), sin(kth));
        y.push_back(q[k] + (wk * r[k]));
    }
    for (int k = 0; k < N/2; k++) {
        double kth = -2 * k * M_PI / N;
        complex<double> wk = complex<double>(cos(kth), sin(kth));
        y.push_back(q[k] - (wk * r[k])); // you didn't do this
    }
    return y;
}

vector<complex<double>> DtoA (double x)
{
    vector<complex<double>> ans(8);
    long X = long(x*100000000); // a 0 was missing  here
    ans[7] = complex<double>(double(X%10), 0.0); X/=10;
    ans[6] = complex<double>(double(X%10), 0.0); X/=10;
    ans[5] = complex<double>(double(X%10), 0.0); X/=10;
    ans[4] = complex<double>(double(X%10), 0.0); X/=10;
    ans[3] = complex<double>(double(X%10), 0.0); X/=10;
    ans[2] = complex<double>(double(X%10), 0.0); X/=10;
    ans[1] = complex<double>(double(X%10), 0.0); X/=10;
    ans[0] = complex<double>(double(X%10), 0.0);
    return ans;
}

int main ()
{
    double n = 0.1234;
    auto nComplex = DtoA(n);
    for (const auto &e : nComplex) {
        std::cout << e << " ";
    }
    std::cout << std::endl;

    try {
        auto nFFT = fft(nComplex);

        for (const auto &e : nFFT) {
            std::cout << e << " ";
        }
    }
    catch (const std::exception &e) {
        std::cout << "exception" << std::endl;
    }
  return 0;
}

程序的输出(我用Octave检查过它,它是一样的):

(1,0) (2,0) (3,0) (4,0) (0,0) (0,0) (0,0) (0,0) 
(10,0) (-0.414214,-7.24264) (-2,2) (2.41421,-1.24264) (-2,0) (2.41421,1.24264) (-2,-2) (-0.414214,7.24264)

希望这有帮助。

修改

关于逆FFT,您可以证明

iFFT(x) = (1 / N) conjugate( FFT( conjugate(x) ) )

其中N是数组x中元素的数量。所以你可以使用fft函数来计算ifft:

vector<complex<double>> ifft(const vector<complex<double>> &vec) {
    std::vector<complex<double>> conj;
    for (const auto &e : vec) {
        conj.push_back(std::conj(e));
    }

    std::vector<complex<double>> vecFFT = fft(conj);

    std::vector<complex<double>> result;
    for (const auto &e : vecFFT) {
        result.push_back(std::conj(e) / static_cast<double>(vec.size()));
    }

    return result;
}

以下是修改过的主要内容:

int main ()
{
    double n = 0.1234;
    auto nComplex = DtoA(n);
    for (const auto &e : nComplex) {
        std::cout << e << " ";
    }
    std::cout << std::endl;

    auto nFFT = fft(nComplex);
    for (const auto &e : nFFT)
        std::cout << e << " ";
    std::cout << std::endl;

    auto iFFT = ifft(nFFT);
    for (const auto &e : iFFT)
        std::cout << e << " ";
    return 0;
}

和输出:

(1,0) (2,0) (3,0) (4,0) (0,0) (0,0) (0,0) (0,0) 
(10,0) (-0.414214,-7.24264) (-2,2) (2.41421,-1.24264) (-2,0) (2.41421,1.24264) (-2,-2) (-0.414214,7.24264) 
(1,-0) (2,-1.08163e-16) (3,7.4688e-17) (4,2.19185e-16) (0,-0) (0,1.13882e-16) (0,-7.4688e-17) (0,-2.24904e-16)

请注意,像1e-16这样的数字几乎为0(硬件中的双精度数并不完美)。