CuFFT Double to Complex

时间:2014-07-17 16:45:10

标签: cuda cufft

我想用CuFFT Lib从double到std :: complex进行FFT。我的代码看起来像

#include <complex>
#include <iostream>
#include <cufft.h>
#include <cuda_runtime_api.h>

typedef std::complex<double> Complex;
using namespace std;

int main(){
  int n = 100;
  double* in;
  Complex* out;
  in = (double*) malloc(sizeof(double) * n);
  out = (Complex*) malloc(sizeof(Complex) * n/2+1);
  for(int i=0; i<n; i++){
     in[i] = 1;
  }

  cufftHandle plan;
  plan = cufftPlan1d(&plan, n, CUFFT_D2Z, 1);
  unsigned int mem_size = sizeof(double)*n;
  cufftDoubleReal *d_in;
  cufftDoubleComplex *d_out;
  cudaMalloc((void **)&d_in, mem_size);
  cudaMalloc((void **)&d_out, mem_size);
  cudaMemcpy(d_in, in, mem_size, cudaMemcpyHostToDevice);
  cudaMemcpy(d_out, out, mem_size, cudaMemcpyHostToDevice);
  int succes = cufftExecD2Z(plan,(cufftDoubleReal *) d_in,(cufftDoubleComplex *) d_out);
  cout << succes << endl;
  cudaMemcpy(out, d_out, mem_size, cudaMemcpyDeviceToHost);

  for(int i=0; i<n/2; i++){
     cout << "out: " << i << " "  << out[i].real() << " " <<  out[i].imag() << endl;
  }
  return 0;
}

但在我看来这肯定是错的,因为我认为变换后的值应该是1 0 0 0 0 ....或没有归一化100 0 0 0 0 ....但我只是得到0 0 0 0 0 ......

此外,如果cufftExecD2Z可以在适当的位置工作,我希望更多,这应该是可能的,但我还没有弄清楚如何正确地这样做。有人可以帮忙吗?

1 个答案:

答案 0 :(得分:1)

您的代码有各种错误。您应该查看cufft documentation以及示例代码。

  1. 您应该对所有API返回值进行适当的cuda错误检查和正确的cufft错误检查。
  2. cufftPlan1d函数的返回值不会进入计划:

    plan = cufftPlan1d(&plan, n, CUFFT_D2Z, 1);
    

    函数本身设置计划(这就是你将&plan传递给函数的原因),然后当你将返回值分配给计划时,它会破坏函数设置的计划。

  3. 您正确识别出输出的大小为((N/2)+1),但是您没有在主机端正确分配空间:

    out = (Complex*) malloc(sizeof(Complex) * n/2+1);
    

    或在设备方面:

    unsigned int mem_size = sizeof(double)*n;
    ...
    cudaMalloc((void **)&d_out, mem_size);
    
  4. 以下代码修复了上述一些问题,足以得到您想要的结果(100,0,0,...)

    #include <complex>
    #include <iostream>
    #include <cufft.h>
    #include <cuda_runtime_api.h>
    
    #define cudaCheckErrors(msg) \
        do { \
            cudaError_t __err = cudaGetLastError(); \
            if (__err != cudaSuccess) { \
                fprintf(stderr, "Fatal error: %s (%s at %s:%d)\n", \
                    msg, cudaGetErrorString(__err), \
                    __FILE__, __LINE__); \
                fprintf(stderr, "*** FAILED - ABORTING\n"); \
                exit(1); \
            } \
        } while (0)
    
    
    typedef std::complex<double> Complex;
    using namespace std;
    
    int main(){
      int n = 100;
      double* in;
      Complex* out;
    #ifdef IN_PLACE
      in = (double*) malloc(sizeof(Complex) * (n/2+1));
      out = (Complex*)in;
    #else
      in = (double*) malloc(sizeof(double) * n);
      out = (Complex*) malloc(sizeof(Complex) * (n/2+1));
    #endif
      for(int i=0; i<n; i++){
         in[i] = 1;
      }
    
      cufftHandle plan;
      cufftResult res = cufftPlan1d(&plan, n, CUFFT_D2Z, 1);
      if (res != CUFFT_SUCCESS)  {cout << "cufft plan error: " << res << endl; return 1;}
      cufftDoubleReal *d_in;
      cufftDoubleComplex *d_out;
      unsigned int out_mem_size = (n/2 + 1)*sizeof(cufftDoubleComplex);
    #ifdef IN_PLACE
      unsigned int in_mem_size = out_mem_size;
      cudaMalloc((void **)&d_in, in_mem_size);
      d_out = (cufftDoubleComplex *)d_in;
    #else
      unsigned int in_mem_size = sizeof(cufftDoubleReal)*n;
      cudaMalloc((void **)&d_in, in_mem_size);
      cudaMalloc((void **)&d_out, out_mem_size);
    #endif
      cudaCheckErrors("cuda malloc fail");
      cudaMemcpy(d_in, in, in_mem_size, cudaMemcpyHostToDevice);
      cudaCheckErrors("cuda memcpy H2D fail");
      res = cufftExecD2Z(plan,d_in, d_out);
      if (res != CUFFT_SUCCESS)  {cout << "cufft exec error: " << res << endl; return 1;}
      cudaMemcpy(out, d_out, out_mem_size, cudaMemcpyDeviceToHost);
      cudaCheckErrors("cuda memcpy D2H fail");
    
      for(int i=0; i<n/2; i++){
         cout << "out: " << i << " "  << out[i].real() << " " <<  out[i].imag() << endl;
      }
      return 0;
    }
    

    回顾the documentation关于在真实到复杂情况下进行就地变换的必要条件。可以使用-DIN_PLACE重新编译上述代码,以查看就地转换的行为以及必要的代码更改。