如何有效地过滤每行矩阵的最大元素

时间:2019-01-07 16:48:05

标签: python arrays numpy max

给出2D数组,我正在寻找一种pythonic的方法来获得形状相同的数组,每行只有最大的元素。 请参见下面的max_row_filter函数

def max_row_filter(mat2d):
    m = np.zeros(mat2d.shape)
    for r in range(mat2d.shape[0]):
        c = np.argmax(mat2d[r])
        m[r,c]=mat2d[r,c]
    return m

p = np.array([[1,2,3],[5,4,3,],[9,10,3]])
max_row_filter(p)

Out: array([[ 0.,  0.,  3.],
            [ 5.,  0.,  0.],
            [ 0., 10.,  0.]])

我正在寻找一种有效的方法来执行此操作,适合在大型阵列上完成。

2 个答案:

答案 0 :(得分:1)

替代答案(这将保留重复项):

p * (p==p.max(axis=1, keepdims=True))

答案 1 :(得分:0)

如果没有重复项,则可以使用numpy.argmax

import numpy as np

p = np.array([[1, 2, 3],
              [5, 4, 3, ],
              [9, 10, 3]])

result = np.zeros_like(p)

rows, cols = zip(*enumerate(np.argmax(p, axis=1)))
result[rows, cols] = p[rows, cols]

print(result)

输出

[[ 0  0  3]
 [ 5  0  0]
 [ 0 10  0]]

请注意,对于多次出现,argmax返回第一次出现。