假设我有一个3D数组:
>>> a
array([[[7, 0],
[3, 6]],
[[2, 4],
[5, 1]]])
我可以使用
在argmax
旁边获取axis=1
>>> m = np.argmax(a, axis=1)
>>> m
array([[0, 1],
[1, 0]])
如何将m
用作a
的索引,以便结果等同于使用max
?
>>> a.max(axis=1)
array([[7, 6],
[5, 4]])
(当m
应用于其他相同形状的数组时,这很有用)
答案 0 :(得分:3)
您可以使用advanced indexing和numpy broadcasting执行此操作:
m = np.argmax(a, axis=1)
a[np.arange(a.shape[0])[:,None], m, np.arange(a.shape[2])]
#array([[7, 6],
# [5, 4]])
m = np.argmax(a, axis=1)
创建第1,第2和第3维度索引的数组:
ind1, ind2, ind3 = np.arange(a.shape[0])[:,None], m, np.arange(a.shape[2])
由于尺寸不匹配,三个阵列将广播,导致每个阵列如下:
for x in np.broadcast_arrays(ind1, ind2, ind3):
print(x, '\n')
#[[0 0]
# [1 1]]
#[[0 1]
# [1 0]]
#[[0 1]
# [0 1]]
由于所有索引都是整数数组,因此它会触发advanced indexing,因此会拾取索引为(0, 0, 0), (0, 1, 1), (1, 1, 0), (1, 0, 1)
的元素,即每个数组中的一个元素组合为索引。
答案 1 :(得分:2)
您可以使用np.ogrid
在阵列的所有轴上创建网格,除了缩小的网格。然后只需将argmax
结果插入轴的位置,并将index数组插入结果:
>>> import numpy as np
>>> a = np.array([[[7, 0], [3, 6]], [[2, 4], [5, 1]]])
>>> axis = 1
>>> # Create the grid
>>> idx = list(np.ogrid[[slice(a.shape[ax]) for ax in range(a.ndim) if ax != axis]])
>>> argmaxes = np.argmax(a, axis=axis)
>>> idx.insert(axis, argmaxes)
>>> # Index the original array with the grid
>>> a[idx]
array([[7, 6],
[5, 4]])
请注意,这不适用于axis=None
,或者如果您减少了多个轴。