我有一个2维NumPy ndarray。
array([[ 0., 20., -2.],
[ 2., 1., 0.],
[ 4., 3., 20.]])
如何获得最大元素的所有索引?所以我想作为输出数组([0,1],[2,2])。
答案 0 :(得分:2)
在 max-equality mask -
上使用np.argwhere
np.argwhere(a == a.max())
示例运行 -
In [552]: a # Input array
Out[552]:
array([[ 0., 20., -2.],
[ 2., 1., 0.],
[ 4., 3., 20.]])
In [553]: a == a.max() # Max equality mask
Out[553]:
array([[False, True, False],
[False, False, False],
[False, False, True]], dtype=bool)
In [554]: np.argwhere(a == a.max()) # array of row, col indices of max-mask
Out[554]:
array([[0, 1],
[2, 2]])
如果使用浮点数,可能需要在那里使用一些容差。因此,考虑到这一点,您可以使用具有一些默认绝对和相对容差值的np.isclose
。这将替换早期的a == a.max()
部分,如此 -
In [555]: np.isclose(a, a.max())
Out[555]:
array([[False, True, False],
[False, False, False],
[False, False, True]], dtype=bool)