3D or 4D NumPy array using a 2D mask & whole-matrix operations

Time: 2017-09-02 13:28:32

Tags: python arrays python-3.x numpy

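I'm trying to perform certain operations on an entire 3D or 4D array, but only on subgroups of smaller dimension (2D arrays contained within the larger array).

Example: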

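import numpy as np

def foo(window):
    # placeholder for the real operation: reduces the two window axes
    return window.sum(axis=(-2, -1))

input = np.arange(75).reshape((3, 5, 5))  # or any other 3D or 4D matrix
mask_hor = np.arange(-1, 2)
mask_ver = mask_hor[:, None]

output = np.zeros((3, 3, 3))

# window centres run from 1 to 3 so the 3x3 window stays inside the 5x5 slice
for i in range(1, 4):
    for j in range(1, 4):
        output[:, i-1, j-1] = foo(input[:, i+mask_ver, j+mask_hor])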

where foo is some operation on each 3x3 window of the input.
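To see why the indexing inside the loop picks out a 3x3 window: mask_ver has shape (3, 1) and mask_hor has shape (3,), so i + mask_ver and j + mask_hor broadcast to a 3x3 grid of (row, column) indices. The short check below compares that against a plain slice; the array is renamed arr (an illustrative choice, to avoid shadowing Python's built-in input) and the centre (2, 3) is arbitrary.

```python
import numpy as np

arr = np.arange(75).reshape((3, 5, 5))
mask_hor = np.arange(-1, 2)   # [-1, 0, 1], shape (3,)
mask_ver = mask_hor[:, None]  # [[-1], [0], [1]], shape (3, 1)

i, j = 2, 3
# The two index arrays broadcast to shape (3, 3), so advanced indexing
# selects the 3x3 neighbourhood centred on (i, j) in every depth slice.
window = arr[:, i + mask_ver, j + mask_hor]
print(window.shape)  # (3, 3, 3)

# Same values as an ordinary slice (which returns a view, not a copy).
same = np.array_equal(window, arr[:, i - 1:i + 2, j - 1:j + 2])
print(same)  # True
```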

My question is: is there a method or mask I can pass to the input so that I can get rid of the nested for loops? I'm mainly looking for a speedup.

Thanks!
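One literal reading of "a mask I can pass to the input" is to extend mask_ver and mask_hor with the coordinates of every window centre, so that a single advanced-indexing call gathers all windows at once. This is a sketch under the assumption that foo can reduce the last two axes of its argument (as a sum does); all_windows and the stand-in foo are hypothetical names, not part of the question.

```python
import numpy as np

def all_windows(arr, half=1):
    """Gather every (2*half+1) x (2*half+1) window over the last two axes.

    Returns shape (..., n - 2*half, m - 2*half, k, k) with k = 2*half + 1.
    """
    n, m = arr.shape[-2:]
    offsets = np.arange(-half, half + 1)  # same role as mask_hor / mask_ver
    rows = np.arange(half, n - half)      # every valid window centre (row)
    cols = np.arange(half, m - half)      # every valid window centre (column)
    # Broadcast to (n-2h, 1, k, 1) and (1, m-2h, 1, k) -> (n-2h, m-2h, k, k)
    row_idx = rows[:, None, None, None] + offsets[None, None, :, None]
    col_idx = cols[None, :, None, None] + offsets[None, None, None, :]
    # One advanced-indexing call replaces both for loops (this makes a copy).
    return arr[..., row_idx, col_idx]

def foo(window):
    # Hypothetical stand-in for the question's foo: reduce the window axes.
    return window.sum(axis=(-2, -1))

arr = np.arange(75).reshape((3, 5, 5))
output = foo(all_windows(arr))
print(output.shape)     # (3, 3, 3)
print(output[0, 0, 0])  # 54
```

The Ellipsis lets the same call handle 4D inputs. Because advanced indexing copies, the gathered array holds k*k times as many elements as the input, so this trades memory for speed.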

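1 answer:

Answer 0 (score: 1)

This is more quick and dirty than elegant. For the sake of argument, let's use the sum of the 9 elements in the window as the foo function.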

import numpy as np
from scipy import ndimage

# take the sum of a 3x3 window of a matrix
def foo_lin_mat(mat):
    return mat.sum(axis=(-2, -1))  # sum over the last two axes

# sum up the individual matrices
def foo_lin_nine(m1, m2, m3, m4, m5, m6, m7, m8, m9):
    return m1 + m2 + m3 + m4 + m5 + m6 + m7 + m8 + m9

# compute foo on an input matrix by shifting the mask around
def nestfor(input, foo):
    depth, n, m = input.shape
    output = np.zeros((depth, n - 2, m - 2))
    mask_hor = np.arange(-1, 2)
    mask_ver = mask_hor[:, None]
    for i in range(1, n - 1):
        for j in range(1, m - 1):
            output[:, i - 1, j - 1] = foo(input[:, i + mask_ver, j + mask_hor])
    return output

# compute foo on an input matrix by breaking the input matrix into 9 submatrices 
def flatargs(input, foo):
    depth, n, m = input.shape
    return foo(input[:, :n-2, :m-2], 
               input[:, 1:n-1, :m-2],
               input[:, 2:, :m-2],
               input[:, :n-2, 1:m-1],
               input[:, 1:n-1, 1:m-1],
               input[:, 2:, 1:m-1],
               input[:, :n-2, 2:],
               input[:, 1:n-1, 2:],
               input[:, 2:, 2:])

# compute the sum of a window using ndimage.convolve
def convolve(input, mask):
    # mask must have as many dimensions as input, e.g. np.ones((1, 3, 3))
    out = ndimage.convolve(input, mask)
    # cut off the outer edges of the two spatial axes (keep every depth slice)
    return out[:, 1:-1, 1:-1]

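So we have three functions that take a matrix and sum up each 3x3 window. I've confirmed that they all produce the same matrix. As for the benchmarks: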
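The agreement between the three approaches can be checked directly. The sketch below re-implements them in condensed form on the same 3x5x5 input (variable names are illustrative). Note that the convolution result is trimmed on the last two axes only; trimming the first two axes would cut the depth axis instead.

```python
import numpy as np
from scipy import ndimage

arr = np.arange(75).reshape((3, 5, 5))
depth, n, m = arr.shape

# Reference: explicit loops over every 3x3 window (as in nestfor)
loop = np.zeros((depth, n - 2, m - 2))
for i in range(1, n - 1):
    for j in range(1, m - 1):
        loop[:, i - 1, j - 1] = arr[:, i - 1:i + 2, j - 1:j + 2].sum(axis=(-2, -1))

# Sum of the nine shifted sub-arrays (as in flatargs)
shifted = sum(arr[:, di:n - 2 + di, dj:m - 2 + dj]
              for di in range(3) for dj in range(3))

# ndimage.convolve with a (1, 3, 3) kernel, trimmed on the spatial axes only
conv = ndimage.convolve(arr, np.ones((1, 3, 3)))[:, 1:-1, 1:-1]

print(np.array_equal(loop, shifted), np.array_equal(loop, conv))  # True True
```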

In [62]: %timeit  nestfor(input, foo_lin_mat)
1000 loops, best of 3: 261 µs per loop

In [63]: %timeit flatargs(input, foo_lin_nine)
10000 loops, best of 3: 35.8 µs per loop

In [66]: mask = np.ones((1,3,3))

In [69]: %timeit convolve(input, mask)
The slowest run took 6.12 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 42.2 µs per loop

That is, the flatargs version is about 7 times faster than the original nested for loop version (35.8 µs vs. 261 µs per loop).
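The flatargs trick works because a sum can be written in terms of the nine shifted sub-arrays. For a foo that is not of that form (for example a maximum over each window), NumPy 1.20 and later provide numpy.lib.stride_tricks.sliding_window_view, which exposes all windows as a read-only strided view without copying the data. This is a different technique from the one in the answer, sketched here under the assumption that foo accepts the two trailing window axes; window_apply is a hypothetical helper name, and no timings are claimed for it.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # NumPy >= 1.20

def window_apply(arr, foo, size=3):
    """Apply foo to every size x size window over the last two axes.

    sliding_window_view returns a read-only view of shape
    (..., n - size + 1, m - size + 1, size, size); foo must reduce
    the two trailing window axes.
    """
    windows = sliding_window_view(arr, (size, size), axis=(-2, -1))
    return foo(windows)

arr = np.arange(75).reshape((3, 5, 5))

# A non-linear foo (max over each window) that a convolution cannot express
out_max = window_apply(arr, lambda w: w.max(axis=(-2, -1)))
print(out_max.shape)  # (3, 3, 3)
print(out_max[0, 0])  # [12 13 14]
```

Because axis=(-2, -1) is used, the same call also works unchanged for 4D inputs.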

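If your foo function is a linear function of the input window, you can also use ndimage.convolve to do the windowing, as in the convolve function here. The resulting code may be easier to read, but you have to be careful with the array you use as the mask: it must have as many dimensions as the input (hence np.ones((1, 3, 3)) for a 3D input), and the borders must be trimmed only on the two spatial axes.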
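Two details make the mask easy to get wrong. First, ndimage.convolve needs weights with the same number of dimensions as the input, so a 2D kernel gets leading length-1 axes ((1, 3, 3) for 3D, (1, 1, 3, 3) for 4D). Second, convolve mirrors the kernel; that is harmless for np.ones((1, 3, 3)), but for a non-symmetric weighted sum of the form (window * kernel).sum() the matching call is ndimage.correlate. A sketch under those assumptions, for odd-sized kernels only (window_weighted_sum is a hypothetical name):

```python
import numpy as np
from scipy import ndimage

def window_weighted_sum(arr, kernel2d):
    """(window * kernel2d).sum() for every full window over the last two axes.

    Assumes an odd-sized 2D kernel; works for 3D, 4D, ... inputs.
    """
    kernel2d = np.asarray(kernel2d, dtype=float)
    kh, kw = kernel2d.shape
    # Prepend length-1 axes so weights.ndim == arr.ndim; those axes also keep
    # neighbouring depth slices from being mixed together.
    weights = kernel2d.reshape((1,) * (arr.ndim - 2) + (kh, kw))
    # correlate applies the kernel as written (convolve would mirror it);
    # a float output dtype avoids truncation for integer inputs.
    out = ndimage.correlate(arr, weights, output=np.result_type(arr, weights))
    # Keep only positions where the whole window lies inside the array.
    n, m = arr.shape[-2:]
    return out[..., kh // 2:n - kh // 2, kw // 2:m - kw // 2]

arr4d = np.arange(2 * 3 * 5 * 6, dtype=float).reshape((2, 3, 5, 6))
kernel = np.arange(9.0).reshape((3, 3))  # deliberately non-symmetric
print(window_weighted_sum(arr4d, kernel).shape)  # (2, 3, 3, 4)
```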