Python:在3D矩阵中沿一个轴查找最大值,将非最大值设为零

时间:2019-03-08 03:47:32

标签: python numpy

我有一个如下所示的3-d矩阵,并希望沿轴1取最大值,并将所有非最大值保持为零。

A = np.random.rand(3,3,2)

  [[[0.34444547, 0.50260393],
    [0.93374423, 0.39021899],
    [0.94485653, 0.9264881 ]],

   [[0.95446736, 0.335068  ],
    [0.35971558, 0.11732342],
    [0.72065402, 0.36436023]],

   [[0.56911013, 0.04456443],
    [0.17239996, 0.96278067],
    [0.26004909, 0.06767436]]]

所需结果:

   [[0         , 0         ],
    [0         , 0         ],
    [0.94485653, 0.9264881]],

   [[0.95446736, 0          ],
    [0         , 0          ],
    [0         , 0.36436023]],

   [[0.56911013, 0         ],
    [0         , 0.96278067],
    [0         , 0         ]]])

我尝试过:

B = np.zeros_like(A)  #return matrix of zero with same shape as A

max_idx = np.argmax(A, axis=1) #index along axis 1 with max value

    array([[2, 0],
           [2, 2],
           [0, 2],
           [0, 1]])

C = np.max(A, axis=1, keepdims = True) #gives a (4,1,2) matrix of max value along axis 1

    array([[[0.95377958, 0.92940525]],
           [[0.94485653, 0.9264881 ]],
           [[0.95446736, 0.36436023]],
           [[0.56911013, 0.96278067]]])

但是我不知道如何将这些想法结合在一起以获得我想要的输出。请帮忙!

1 个答案:

答案 0 :(得分:1)

您可以从max_idx中获取最大值的3 dimensional indexmax_idx中的值是沿着最大值的轴1的索引。由于您的其他轴分别是3和2(3 x 2 = 6),因此有六个值。您只需要了解numpy通过它们的顺序即可获得其他每个轴的索引。您首先遍历最后一个轴:

d0, d1, d2 = A.shape
a0 = [i for i in range(d0) for _ in range(d2)]   # [0, 0, 1, 1, 2, 2]
a1 = max_idx.flatten()                           # [2, 2, 0, 2, 0, 1]
a2 = [k for _ in range(d0) for k in range(d2)]   # [0, 1, 0, 1, 0, 1]

B[a0, a1, a2] = A[a0, a1, a2]

输出:

array([[[0.        , 0.        ],
        [0.        , 0.        ],
        [0.94485653, 0.9264881 ]],

       [[0.95446736, 0.        ],
        [0.        , 0.        ],
        [0.        , 0.36436023]],

       [[0.56911013, 0.        ],
        [0.        , 0.96278067],
        [0.        , 0.        ]]])