在numpy数组的第3维上使用numpy argmax

时间:2019-05-03 16:34:28

标签: python arrays numpy

我有一个形状为(150, 9, 5)的3d numpy数组,我想获得第3维的argmax(在元素之外),但我无法完成。 例如,假设数组看起来像(此示例未遵循我在代码中使用的形状):

[[[615, 142, 18]],

   [[12, 412, 98]],

   [[5, 11, 162]],

   [[19, 76, 7]],

   [[56, 78, 12]],

   [[119, 41, 55]],

   [[26, 7, 15]],

   [[19, 67, 53]]]

我想找回形状数组(150、9、1)(同样,在我的情况下,与示例无关)。对于该示例,将为:

[[[0]],

   [[1]],

   [[2]],

   [[1]],

   [[1]],

   [[0]],

   [[0]],

   [[1]]]

当我同时使用np.argmax()axis 0 and 1时,我得到了错误的结果。

有没有一种方法可以直接解决这个问题,还是应该使用for循环遍历(9,5)?

1 个答案:

答案 0 :(得分:0)

您有带有额外尺寸的额外轴:

import numpy as np
a=np.array([[[615, 142, 18]],

   [[12, 412, 98]],

   [[5, 11, 162]],

   [[19, 76, 7]],

   [[56, 78, 12]],

   [[119, 41, 55]],

   [[26, 7, 15]],

   [[19, 67, 53]]])

a.argmax(axis=2)

输出:

Out[82]: 
array([[0],
       [1],
       [2],
       [1],
       [1],
       [0],
       [0],
       [1]], dtype=int64)