我有以下代码产生两个矩阵的卷积。我遇到的问题是卷积很耗费内存。关于如何加快速度的任何想法?
Matlab是否有更友好的版本?我应该预先分配到某个地方吗?
function res = fftconv(data,query)
N = size(data,1);
R = size(query,1);
C = size(query,2);
query(end+1:N,end+1:N)=0;
temp = ifft2(fft2(data).*fft2(query));
res = temp(R:end,C:end);
end
答案 0 :(得分:1)
您的方法可能会意外地计算“坏”长度的FFT,即具有较大素因子的数字。
此外,您的方法进行循环卷积:它不匹配Matlab的内置conv2
的输出而没有零填充。 (回想一下,当你将两个输入都归零到nx + ny - 1
时,使用FFT的循环卷积等同于时域线性卷积。)
这是一个可以使用的简单函数,它返回与conv2
相同的值:
function z = conv2fft(x, y, nfft)
nx = size(x);
ny = size(y);
nz = nx + ny - 1;
if ~exist('nfft', 'var') || isempty(nfft)
nfft = 2 .^ nextpow2(nz);
else
assert(all(nfft >= nz), 'nfft >= nx + ny - 1 for linear convolution');
end
zfull = ifft2(fft2(x, nfft(1), nfft(2)) .* fft2(y, nfft(1), nfft(2)));
z = zfull(1 : nz(1), 1 : nz(2));
检查出来,它有效:
>> x = randn(10, 11);
>> y = randn(4, 3);
>> z1 = conv2(x, y);
>> z2 = conv2fft(x, y);
>> max(abs(z2(:) - z1(:)))
ans =
2.2204e-15
两者之间的误差非常小,即使对于矩形输入也是如此。您需要对数据进行基准测试,以确认它更快。
关于速度的一个重要警告:如果没有提供,则此函数使用2的幂的默认nfft
。有时这不是最好的。例如,如果nx + ny - 1
是[1025, 1025]
(即conv2
的输出是1025乘1025),则默认将导致2048乘2048个中间数组,这可能比1025慢1025 !这是因为FFTW内部必须分配四倍的内存,并且需要4倍的FFT。 如果您知道这种情况,您可以conv2fft
更好nfft
,例如[1080, 1080]
(1080的唯一因素是2,3和5 )。 Julia有一个很好的函数nextprod
,可以让你找到下一个具有某些因子的整数。这是free Matlab version of nextprod
,您可以使用nextprod([2 3 5], 1025)
。这将返回1080。
总结: