我有一些数据y_hat
如下:
[[0. 1. 0. ... 0. 0. 0.]
[0. 1. 0. ... 0. 0. 0.]
[0. 1. 0. ... 0. 0. 0.]
...
[0. 1. 0. ... 0. 0. 0.]
[0. 1. 0. ... 0. 0. 0.]
[0. 1. 0. ... 0. 0. 0.]]
我想获得每一行的argmax
,以便得到一个像这样的向量:
[[3]
[8]
[8]
...
[5]
[1]
[7]]
如果我只做np.argmax(y_hat)
,它将返回一个1
。
答案 0 :(得分:1)
np.argmax
accepts an axis
keyword argument。使用它。
列为axis=0
,行为axis=1
答案 1 :(得分:1)
这是argmax
之后numpy
广播的一种方式
a.argmax(axis = 1)[:,None]
或
a[:,None].argmax(-1)