查找2D阵列中四个邻居的最大值

时间:2018-01-08 01:22:47

标签: python performance numpy scipy

我想在2D numpy数组中找到项目的四个邻居的最大值。我提出的第一个解决方案是使用scipy.ndimage.generic_filter

import numpy as np
import scipy.ndimage

a = np.random.uniform(low=10, high=100, size=(6500,6500)).astype(np.float)

def filter2d(footprint_elements):
    return max(footprint_elements)

footprint = np.array([[0, 1, 0],
                       [1, 0, 1],
                       [0, 1, 0]])

maxs = scipy.ndimage.generic_filter(a,filter2d, footprint=footprint)

这里使用通用过滤器的问题是它非常慢,所以我想出了一个更快的解决方案(请注意边缘并不重要):

maxs = np.maximum.reduce([a[:-2, 1:-1], a[1:-1, 2:], a[2:, 1:-1], a[1:-1,:-2]])

我正在寻找可能更快的任何方法。

我不确定是否考虑到这可能会提高速度,但我只对特定项目(由另一个数组确定)感兴趣。例如,查找数组a中数组b大于0的那些项的最大邻居:

b = np.random.uniform(low=-10, high=10, size=(6500,6500)).astype(np.float)

# need to find maximum neighbour of array a where b > 0

 maxs[b > 0]

1 个答案:

答案 0 :(得分:2)

您可以使用this blog post而不是Function<A, R> left, Function<B, R> right。它还接受scipy.ndimage.generic_filter参数:

footprint

定时:

from scipy.ndimage import maximum_filter

maxs = maximum_filter(a, footprint=footprint)

所以In [105]: a = np.random.uniform(low=10, high=100, size=(6500,6500)) In [106]: %timeit maxs = maximum_filter(a, footprint=footprint) 858 ms ± 2.13 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) In [107]: %timeit maxs = np.maximum.reduce([a[:-2, 1:-1], a[1:-1, 2:], a[2:, 1:-1], a[1:-1,:-2]]) 1.34 s ± 12.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) 比使用应用于切片的maximum_filter快一点。