Numpy:如何使用argmax结果获得实际最大值?

时间:2017-10-20 00:43:25

标签: python numpy

假设我有一个3D数组:

>>> a
array([[[7, 0],
        [3, 6]],

       [[2, 4],
        [5, 1]]])

我可以使用

argmax旁边获取axis=1
>>> m = np.argmax(a, axis=1)
>>> m
array([[0, 1],
       [1, 0]])

如何将m用作a的索引,以便结果等同于使用max

>>> a.max(axis=1)
array([[7, 6],
       [5, 4]])

(当m应用于其他相同形状的数组时,这很有用)

2 个答案:

答案 0 :(得分:3)

您可以使用advanced indexingnumpy broadcasting执行此操作:

m = np.argmax(a, axis=1)
a[np.arange(a.shape[0])[:,None], m, np.arange(a.shape[2])]

#array([[7, 6],
#       [5, 4]])
m = np.argmax(a, axis=1)

创建第1,第2和第3维度索引的数组:

ind1, ind2, ind3 = np.arange(a.shape[0])[:,None], m, np.arange(a.shape[2])
​

由于尺寸不匹配,三个阵列将广播,导致每个阵列如下:

for x in np.broadcast_arrays(ind1, ind2, ind3):
    print(x, '\n')

#[[0 0]
# [1 1]] 

#[[0 1]
# [1 0]] 

#[[0 1]
# [0 1]] 

由于所有索引都是整数数组,因此它会触发advanced indexing,因此会拾取索引为(0, 0, 0), (0, 1, 1), (1, 1, 0), (1, 0, 1)的元素,即每个数组中的一个元素组合为索引。

答案 1 :(得分:2)

您可以使用np.ogrid在阵列的所有轴上创建网格,除了缩小的网格。然后只需将argmax结果插入轴的位置,并将index数组插入结果:

>>> import numpy as np
>>> a = np.array([[[7, 0], [3, 6]], [[2, 4], [5, 1]]])
>>> axis = 1

>>> # Create the grid
>>> idx = list(np.ogrid[[slice(a.shape[ax]) for ax in range(a.ndim) if ax != axis]])
>>> argmaxes = np.argmax(a, axis=axis)
>>> idx.insert(axis, argmaxes)

>>> # Index the original array with the grid
>>> a[idx]
array([[7, 6],
       [5, 4]])

请注意,这不适用于axis=None,或者如果您减少了多个轴。