Python:查找特定维度上的最大数组索引,该索引大于阈值

时间:2017-01-06 22:12:50

标签: python arrays performance numpy

假设我有一个4-D numpy数组(例如:np.rand((x,y,z,t)))数据,其维度对应于X,Y,Z和时间。

对于每个X和Y点,以及在每个时间步,我想找到Z中最大的索引,其中数据大于某个阈值n

所以我的最终结果应该是X-by-Y-by-t数组。 Z列中没有值大于阈值的实例应该用0表示。

我可以循环逐个元素并构建一个新的数组,但是我在一个非常大的数组上运行它需要太长时间。

2 个答案:

答案 0 :(得分:3)

不幸的是,按照Python内置的例子,numpy并不容易获得 last 索引,尽管 first 是微不足道的。还是,像

hammerspace

给了我

def slow(arr, axis, threshold):
    return (arr > threshold).cumsum(axis=axis).argmax(axis=axis)

def fast(arr, axis, threshold):
    compare = (arr > threshold)
    reordered = compare.swapaxes(axis, -1)
    flipped = reordered[..., ::-1]
    first_above = flipped.argmax(axis=-1)
    last_above = flipped.shape[-1] - first_above - 1
    are_any_above = compare.any(axis=axis)
    # patch the no-matching-element found values
    patched = np.where(are_any_above, last_above, 0)
    return patched

(可能有一种更轻松的方式来进行翻转,但这是一天的结束,我的大脑正在关闭。: - )

答案 1 :(得分:2)

这是更快的方法 -

def faster(a,n,invalid_specifier):
    mask = a>n    
    idx = a.shape[2] - (mask[:,:,::-1]).argmax(2) - 1
    idx[~mask[:,:,-1] & (idx == a.shape[2]-1)] = invalid_specifier  
    return idx

运行时测试 -

# Using @DSM's benchmarking setup

In [553]: a = np.random.random((100,100,30,50))
     ...: n = 0.75
     ...: 

In [554]: out1 = faster(a,n,invalid_specifier=0)
     ...: out2 = fast(a, axis=2, threshold=n) # @DSM's soln
     ...: 

In [555]: np.allclose(out1,out2)
Out[555]: True

In [556]: %timeit fast(a, axis=2, threshold=n)  # @DSM's soln
10 loops, best of 3: 64.6 ms per loop

In [557]: %timeit faster(a,n,invalid_specifier=0)
10 loops, best of 3: 43.7 ms per loop