我有一个形状为(150, 9, 5)
的3d numpy数组,我想获得第3维的argmax(在元素之外),但我无法完成。
例如,假设数组看起来像(此示例未遵循我在代码中使用的形状):
[[[615, 142, 18]],
[[12, 412, 98]],
[[5, 11, 162]],
[[19, 76, 7]],
[[56, 78, 12]],
[[119, 41, 55]],
[[26, 7, 15]],
[[19, 67, 53]]]
我想找回形状数组(150、9、1)(同样,在我的情况下,与示例无关)。对于该示例,将为:
[[[0]],
[[1]],
[[2]],
[[1]],
[[1]],
[[0]],
[[0]],
[[1]]]
当我同时使用np.argmax()
和axis 0 and 1
时,我得到了错误的结果。
有没有一种方法可以直接解决这个问题,还是应该使用for循环遍历(9,5)?
答案 0 :(得分:0)
您有带有额外尺寸的额外轴:
import numpy as np
a=np.array([[[615, 142, 18]],
[[12, 412, 98]],
[[5, 11, 162]],
[[19, 76, 7]],
[[56, 78, 12]],
[[119, 41, 55]],
[[26, 7, 15]],
[[19, 67, 53]]])
a.argmax(axis=2)
输出:
Out[82]:
array([[0],
[1],
[2],
[1],
[1],
[0],
[0],
[1]], dtype=int64)