如何在列表数组上使用np argmax?

时间:2019-04-17 23:18:44

标签: python numpy

我有一些数据y_hat如下:

[[0. 1. 0. ... 0. 0. 0.]
 [0. 1. 0. ... 0. 0. 0.]
 [0. 1. 0. ... 0. 0. 0.]
 ...
 [0. 1. 0. ... 0. 0. 0.]
 [0. 1. 0. ... 0. 0. 0.]
 [0. 1. 0. ... 0. 0. 0.]]

我想获得每一行的argmax,以便得到一个像这样的向量:

[[3]
 [8]
 [8]
 ...
 [5]
 [1]
 [7]]

如果我只做np.argmax(y_hat),它将返回一个1

2 个答案:

答案 0 :(得分:1)

np.argmax accepts an axis keyword argument。使用它。

列为axis=0,行为axis=1

答案 1 :(得分:1)

这是argmax之后numpy广播的一种方式

a.argmax(axis = 1)[:,None]

a[:,None].argmax(-1)