我有这个数组:
arr = np.array([3, 7, 4])
这些布尔索引:
cond = np.array([False, True, True])
我想在布尔条件为真的数组中找到最大值的索引。所以我这样做:
np.ma.array(arr, mask=~cond).argmax()
哪个有效并返回1.但如果我有一个布尔索引数组:
cond = np.array([[False, True, True], [True, False, True]])
是否存在迭代/ numpy方式迭代布尔索引数组以返回[1,2]?
答案 0 :(得分:3)
对于argmax
的特殊用例,您可以使用Font Awesome并将屏蔽值设置为负无穷大:
>>> inf = np.iinfo('i8').max
>>> np.where(cond, arr, -inf).argmax(axis=1)
array([1, 2])
或者,您可以使用np.where
手动广播:
>>> np.ma.array(np.tile(arr, 2).reshape(2, 3), mask=~cond).argmax(axis=1)
array([1, 2])
答案 1 :(得分:2)
所以你想要一个矢量化版本:
In [302]: [np.ma.array(arr,mask=~c).argmax() for c in cond]
Out[302]: [1, 2]
cond
的实际尺寸是多少?如果行数与列相比较小(或arr
的长度),则这样的迭代可能并不昂贵。
tile
的{p> https://stackoverflow.com/a/31767220/901925看起来不错。在这里我稍微改变一下:
In [308]: np.ma.array(np.tile(arr,(cond.shape[0],1)),mask=~cond).argmax(axis=1)
Out[308]: array([1, 2], dtype=int32)
正如预期的那样,列表理解时间与cond
行成比例,而平铺方法只比单行情况慢一点。但随着92.7 µs
周围的时间,这种蒙面数组方法比arr.argmax()
慢得多。掩蔽增加了很多开销。
where
版本的速度要快得多
np.where(cond, arr, -100).argmax(1) # 20 µs
建议删除的答案
(arr*cond).argmax(1) # 8 µs
哪个更快。正如所提出的,如果存在负arr
值,则它不起作用。但它可能会被调整以处理这些。
答案 2 :(得分:0)
arr = np.array([3, 7, 4])
cond = np.array([[False, True, True], [True, False, True]])
def multi_slice_max(bool_arr , x ):
return np.ma.array(x, mask=~bool_arr).argmax()
np.apply_along_axis(multi_slice_max , 1 , cond , arr)