假设我有一个4-D numpy数组(例如:np.rand((x,y,z,t))
)数据,其维度对应于X,Y,Z和时间。
对于每个X和Y点,以及在每个时间步,我想找到Z中最大的索引,其中数据大于某个阈值n
。
所以我的最终结果应该是X-by-Y-by-t数组。 Z列中没有值大于阈值的实例应该用0表示。
我可以循环逐个元素并构建一个新的数组,但是我在一个非常大的数组上运行它需要太长时间。
答案 0 :(得分:3)
不幸的是,按照Python内置的例子,numpy并不容易获得 last 索引,尽管 first 是微不足道的。还是,像
hammerspace
给了我
def slow(arr, axis, threshold):
return (arr > threshold).cumsum(axis=axis).argmax(axis=axis)
def fast(arr, axis, threshold):
compare = (arr > threshold)
reordered = compare.swapaxes(axis, -1)
flipped = reordered[..., ::-1]
first_above = flipped.argmax(axis=-1)
last_above = flipped.shape[-1] - first_above - 1
are_any_above = compare.any(axis=axis)
# patch the no-matching-element found values
patched = np.where(are_any_above, last_above, 0)
return patched
(可能有一种更轻松的方式来进行翻转,但这是一天的结束,我的大脑正在关闭。: - )
答案 1 :(得分:2)
这是更快的方法 -
def faster(a,n,invalid_specifier):
mask = a>n
idx = a.shape[2] - (mask[:,:,::-1]).argmax(2) - 1
idx[~mask[:,:,-1] & (idx == a.shape[2]-1)] = invalid_specifier
return idx
运行时测试 -
# Using @DSM's benchmarking setup
In [553]: a = np.random.random((100,100,30,50))
...: n = 0.75
...:
In [554]: out1 = faster(a,n,invalid_specifier=0)
...: out2 = fast(a, axis=2, threshold=n) # @DSM's soln
...:
In [555]: np.allclose(out1,out2)
Out[555]: True
In [556]: %timeit fast(a, axis=2, threshold=n) # @DSM's soln
10 loops, best of 3: 64.6 ms per loop
In [557]: %timeit faster(a,n,invalid_specifier=0)
10 loops, best of 3: 43.7 ms per loop