如何在ndarray中找到所有argmax

时间:2016-11-30 20:42:32

标签: python numpy argmax

我有一个2维NumPy ndarray。

array([[  0.,  20.,  -2.],
   [  2.,   1.,   0.],
   [  4.,   3.,  20.]])

如何获得最大元素的所有索引?所以我想作为输出数组([0,1],[2,2])。

1 个答案:

答案 0 :(得分:2)

max-equality mask -

上使用np.argwhere
np.argwhere(a == a.max())

示例运行 -

In [552]: a   # Input array
Out[552]: 
array([[  0.,  20.,  -2.],
       [  2.,   1.,   0.],
       [  4.,   3.,  20.]])

In [553]: a == a.max() # Max equality mask
Out[553]: 
array([[False,  True, False],
       [False, False, False],
       [False, False,  True]], dtype=bool)

In [554]: np.argwhere(a == a.max()) # array of row, col indices of max-mask
Out[554]: 
array([[0, 1],
       [2, 2]])

如果使用浮点数,可能需要在那里使用一些容差。因此,考虑到这一点,您可以使用具有一些默认绝对和相对容差值的np.isclose。这将替换早期的a == a.max()部分,如此 -

In [555]: np.isclose(a, a.max())
Out[555]: 
array([[False,  True, False],
       [False, False, False],
       [False, False,  True]], dtype=bool)