I'm trying to apply an operation across a whole 3D or 4D array, but only to small sub-blocks of it (2D windows contained in the larger array).
Example:
import numpy as np

input = np.arange(75).reshape((3, 5, 5))  # or any other 3D or 4D matrix
mask_hor = np.arange(-1, 2)
mask_ver = mask_hor[:, None]
output = np.zeros((3, 3, 3))
# i, j are the window centres; for a 5x5 slice they run from 1 to 3
for i in range(1, 4):
    for j in range(1, 4):
        output[:, i-1, j-1] = foo(input[:, i+mask_ver, j+mask_hor])
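where foo is some operation on the input window.
My question is: is there a method or mask I can pass to the input so that I can get rid of the nested for loops? I'm mainly looking for a speed-up.
Thanks!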
Answer 0 (score: 1)
This is more quick-and-dirty than elegant. For the sake of argument, let foo be the function that sums the 9 elements of the window.
import numpy as np
from scipy import ndimage

# take the sum of a 3x3 window of a matrix
def foo_lin_mat(mat):
    return mat.sum(axis=(-2, -1))  # sum over the last two axes

# sum up the individual matrices
def foo_lin_nine(m1, m2, m3, m4, m5, m6, m7, m8, m9):
    return m1 + m2 + m3 + m4 + m5 + m6 + m7 + m8 + m9

# compute foo on an input matrix by shifting the mask around
def nestfor(input, foo):
    depth, n, m = input.shape
    output = np.zeros((depth, n - 2, m - 2))
    mask_hor = np.arange(-1, 2)
    mask_ver = mask_hor[:, None]
    for i in range(1, n - 1):
        for j in range(1, m - 1):
            output[:, i - 1, j - 1] = foo(input[:, i + mask_ver, j + mask_hor])
    return output

# compute foo on an input matrix by breaking the input matrix into 9 submatrices
def flatargs(input, foo):
    depth, n, m = input.shape
    return foo(input[:, :n-2, :m-2],
               input[:, 1:n-1, :m-2],
               input[:, 2:, :m-2],
               input[:, :n-2, 1:m-1],
               input[:, 1:n-1, 1:m-1],
               input[:, 2:, 1:m-1],
               input[:, :n-2, 2:],
               input[:, 1:n-1, 2:],
               input[:, 2:, 2:])

# compute the sum of each window using ndimage.convolve
def convolve(input, mask):
    # mask is expected to be np.ones((1, 3, 3)) so the depth axis is not mixed
    out = ndimage.convolve(input, mask)
    # cut off the outer edges of the last two axes (not the depth axis)
    return out[:, 1:-1, 1:-1]
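So we have three functions that take a matrix and sum each of its 3x3 windows. I've checked that all three produce the same matrix. As for the benchmarks: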
In [62]: %timeit nestfor(input, foo_lin_mat)
1000 loops, best of 3: 261 µs per loop
In [63]: %timeit flatargs(input, foo_lin_nine)
10000 loops, best of 3: 35.8 µs per loop
In [66]: mask = np.ones((1,3,3))
In [69]: %timeit convolve(input, mask)
The slowest run took 6.12 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 42.2 µs per loop
That is, the flatargs version is roughly 7 times faster than the original nested-loop version (35.8 µs vs. 261 µs per loop).
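If foo cannot easily be rewritten in terms of nine shifted slices (a median, say), another way to drop the loops is to build every window at once as a strided view. This is a sketch that is not part of the answer above and assumes NumPy 1.20 or newer, where `numpy.lib.stride_tricks.sliding_window_view` exists; `window_foo` is a hypothetical helper name. Because the windows end up on the last two axes, a reduction like `foo_lin_mat` applies unchanged, and 4D input works the same way.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def window_foo(arr, foo, k=3):
    """Apply foo to every k x k window over the last two axes of arr.

    Hypothetical helper; assumes NumPy >= 1.20. foo must reduce the
    last two axes, e.g. lambda w: w.sum(axis=(-2, -1)).
    """
    # Read-only view of shape arr.shape[:-2] + (n-k+1, m-k+1, k, k); no copy.
    windows = sliding_window_view(arr, (k, k), axis=(-2, -1))
    return foo(windows)


def foo_lin_mat(mat):
    # Same reduction as in the answer: sum over the last two axes.
    return mat.sum(axis=(-2, -1))


def nestfor(arr, foo):
    # Nested-loop reference, as in the answer, used for comparison.
    depth, n, m = arr.shape
    output = np.zeros((depth, n - 2, m - 2))
    mask_hor = np.arange(-1, 2)
    mask_ver = mask_hor[:, None]
    for i in range(1, n - 1):
        for j in range(1, m - 1):
            output[:, i - 1, j - 1] = foo(arr[:, i + mask_ver, j + mask_hor])
    return output


arr = np.arange(75).reshape((3, 5, 5))
fast = window_foo(arr, foo_lin_mat)
print(fast.shape)                                     # (3, 3, 3)
print(np.allclose(fast, nestfor(arr, foo_lin_mat)))  # True

# A non-linear foo, here the median of each window, uses the same call.
med = window_foo(arr, lambda w: np.median(w, axis=(-2, -1)))
print(med[0, 0, 0])                                   # 6.0

# 4D input: the leading axes are carried through untouched.
arr4 = np.arange(2 * 3 * 5 * 5).reshape((2, 3, 5, 5))
print(window_foo(arr4, foo_lin_mat).shape)            # (2, 3, 3, 3)
```

The view itself costs almost nothing to create; the work happens inside foo, so speed still depends on foo being a vectorized reduction over the last two axes.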
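If your foo function is linear in the input window, you can also do the windowing with ndimage.convolve, as in the convolve function here. The resulting code may be easier to read, but you have to be careful with the array you use as the mask: it must have the same number of dimensions as the input, and a shape of (1, 3, 3) keeps the depth slices from being mixed together.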
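One more pitfall with the mask: `scipy.ndimage.convolve` flips the weights along every axis (a true convolution), whereas `scipy.ndimage.correlate` applies them as written, which is what a window loop like nestfor does. The sketch below, assuming 3D input and 3x3 windows as in the answer, uses a hypothetical weighted foo (`weighted_loop`, with an asymmetric `weights` array) to show the difference.

```python
import numpy as np
from scipy import ndimage


def weighted_loop(arr, weights):
    # Reference: weighted sum of each 3x3 window, weights applied as written.
    depth, n, m = arr.shape
    out = np.zeros((depth, n - 2, m - 2))
    for i in range(n - 2):
        for j in range(m - 2):
            out[:, i, j] = (arr[:, i:i + 3, j:j + 3] * weights).sum(axis=(-2, -1))
    return out


def weighted_correlate(arr, weights):
    # correlate does not flip the kernel; the (1, 3, 3) shape leaves the
    # depth axis alone. Cast to float so ndimage keeps fractional results.
    out = ndimage.correlate(arr.astype(float), weights[None, :, :])
    return out[:, 1:-1, 1:-1]  # trim the edges, as in the answer


def weighted_convolve(arr, weights):
    # convolve flips the kernel along every axis before applying it.
    out = ndimage.convolve(arr.astype(float), weights[None, :, :])
    return out[:, 1:-1, 1:-1]


arr = np.arange(75).reshape((3, 5, 5))
weights = np.arange(9, dtype=float).reshape(3, 3)  # deliberately asymmetric

ref = weighted_loop(arr, weights)
print(np.allclose(weighted_correlate(arr, weights), ref))             # True
print(np.allclose(weighted_convolve(arr, weights), ref))              # False
# Reversing the weights by hand makes convolve agree with the loop.
print(np.allclose(weighted_convolve(arr, weights[::-1, ::-1]), ref))  # True
```

With the all-ones mask used in the answer the flip makes no difference, which is why convolve gives the correct window sums there.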