有谁知道计算卷积的最快方法?不幸的是,我处理的矩阵非常大(500x500x200),如果我在MATLAB中使用convn
则需要很长时间(我必须在嵌套循环中迭代这个计算)。所以,我使用FFT进行卷积,现在速度更快。但是,我仍然在寻找一种更快的方法。有什么想法吗?
答案 0 :(得分:16)
如果你的内核是可分离的,那么通过执行多个连续的1D卷积可以实现最大的速度增益。
MathWorks的Steve Eddins描述了当内核在his blog的MATLAB上下文中可分离时,如何利用卷积的相关性来加速卷积。对于P-by-Q
内核,执行两个单独和顺序卷积与2D卷积的计算优势是PQ/(P+Q)
,对应于9x9内核的4.5x和15x15内核的~11x。 编辑:this Q&A给出了一个有趣的无意识的证明这种差异的证据。
要确定内核是否可分离(即两个向量的外积)博客goes on to describe如何检查内核是否可与SVD分离以及如何获取1D内核。他们的例子是2D内核。对于N维可分卷积的解决方案,请检查this FEX submission。
值得指出的另一个资源是this SIMD (SSE3/SSE4) implementation of 3D convolution by Intel,其中包括source和presentation。代码用于16位整数。除非你转向GPU(例如cuFFT),否则很难比英特尔的实施更快,后者还包括Intel MKL。 this page of the MKL documentation底部有一个3D卷积(单精度浮点数)示例(链接已修复,现在在https://stackoverflow.com/a/27074295/2778484中镜像)。
答案 1 :(得分:2)
您可以尝试重叠添加和重叠保存方法。它们涉及将输入信号分解为较小的块,然后使用上述任一方法。
FFT最有可能 - 我可能错了 - 这是最快的方法,特别是如果你在MATLAB中使用内置例程或在C ++中使用库。除此之外,将输入信号分解为更小的块应该是一个不错的选择。
答案 2 :(得分:0)
我有2种方法来计算fastconv
和2比1
1-犰狳 您可以使用armadillo库来使用此代码来加密转换
cx_vec signal(1024,fill::randn);
cx_vec code(300,fill::randn);
cx_vec ans = conv(signal,code);
2-use fftw ans sigpack和armadillo library用于以这种方式调用快速转换你必须在构造函数中初始化你的代码
FastConvolution::FastConvolution(cx_vec inpCode)
{
filterCode = inpCode;
fft_w = NULL;
}
cx_vec FastConvolution::filter(cx_vec inpData)
{
int length = inpData.size()+filterCode.size();
if((length & (length - 1)) == 0)
{
}
else
{
length = pow(2 , (int)log2(length) + 1);
}
if(length != fftCode.size())
initCode(length);
static cx_vec zeroPadedData;
if(length!= zeroPadedData.size())
{
zeroPadedData.resize(length);
}
zeroPadedData.fill(0);
zeroPadedData.subvec(0,inpData.size()-1) = inpData;
cx_vec fftSignal = fft_w->fft_cx(zeroPadedData);
cx_vec mullAns = fftSignal % fftCode;
cx_vec ans = fft_w->ifft_cx(mullAns);
return ans.subvec(filterCode.size(),inpData.size()+filterCode.size()-1);
}
void FastConvolution::initCode(int length)
{
if(fft_w != NULL)
{
delete fft_w;
}
fft_w = new sp::FFTW(length,FFTW_ESTIMATE);
cx_vec conjCode(length,fill::zeros);
fftCode.resize(length);
for(int i = 0; i < filterCode.size();i++)
{
conjCode.at(i) = filterCode.at(filterCode.size() - i - 1);
}
conjCode = conj(conjCode);
fftCode = fft_w->fft_cx(conjCode);
}