我正在使用pyFFTW对二维复数数组进行2D FFT。这些数组可能会变得非常大(~128 GiB),因此执行时间至关重要。 (背景是光学物理学中的波前传播。)
查看以下玩具代码:
import numpy as np
import pyfftw
import multiprocessing
a = np.random.rand(16384, 16384) + 1j*np.random.rand(16384, 16384)
fft = pyfftw.FFTW(a, a, axes = (0, 1), direction = 'FFTW_FORWARD', flags = ('FFTW_ESTIMATE', 'FFTW_UNALIGNED', 'FFTW_DESTROY_INPUT',), threads = multiprocessing.cpu_count())
a = fft()
FFT需要几秒钟才能在我现代的64位机器上执行。
当分两步执行2D FFT(所有列和所有行的1D-FFT)时,结果和执行时间都保持不变:
fft = pyfftw.FFTW(a, a, axes = (0,), direction = 'FFTW_FORWARD', flags = ('FFTW_ESTIMATE', 'FFTW_UNALIGNED', 'FFTW_DESTROY_INPUT',), threads = multiprocessing.cpu_count())
a = fft()
fft = pyfftw.FFTW(a, a, axes = (1,), direction = 'FFTW_FORWARD', flags = ('FFTW_ESTIMATE', 'FFTW_UNALIGNED', 'FFTW_DESTROY_INPUT',), threads = multiprocessing.cpu_count())
a = fft()
然而,单独花费这些步骤的时间表明列-FFT比行FFT 慢大约10倍。
我想,原因是数组被逐行保存到物理RAM中。实际上,a.flags给出了
C_CONTIGUOUS : True
F_CONTIGUOUS : False
OWNDATA : True
WRITEABLE : True
ALIGNED : True
UPDATEIFCOPY : False
而a.strides给出
(262144, 16)
因此,数组是C连续的,似乎是正确对齐的。但是,删除标志'FFTW_UNALIGNED'会使列FFT大约慢10倍(而行FFT变得稍快)。
因此,我的问题是:
对齐可能有问题,或者对列的访问速度比对C连续数组的行的物理限制要慢10倍?
编辑:事实上,因素10似乎太大了。让我们比较行和列的简单读/写访问:
a[:,0:16384:2]*=1j
和
a[0:16384:2,:]*=1j
将列与偶数索引相乘(第一个变体)比使用偶数索引(第二个变体)乘以行的速度慢约2倍。
编辑:在ipython中输入的确切代码是
In [1]: import pyfftw
In [2]: import multiprocessing
In [3]: a = np.random.rand(16384, 16384) + 1j*np.random.rand(16384, 16384)
In [4]: fft = pyfftw.FFTW(a, a, axes = (0, 1), direction = 'FFTW_FORWARD', flags = ('FFTW_ESTIMATE', 'FFTW_UNALIGNED', 'FFTW_DESTROY_INPUT',), threads = multiprocessing.cpu_count())
In [5]: %timeit a = fft()