在二维numpy数组的每一行中找到N个最大值

时间:2020-05-18 08:32:18

标签: python numpy

我有一个二维的numpy数组a

a = np.array(range(0,25)).reshape(5,5)
---
[[ 0  1  2  3  4]
 [ 5  6  7  8  9]
 [10 11 12 13 14]
 [15 16 17 18 19]
 [20 21 22 23 24]]

我想找到每行的最大N个值并将其替换为100。我这样做很慢:

N = 2
idx = a.argsort()
for i in range(a.shape[0]):
    a[i,idx[i][::-1][0:N]] = 100
print(a)
---
[[  0   1   2 100 100]
 [  5   6   7 100 100]
 [ 10  11  12 100 100]
 [ 15  16  17 100 100]
 [ 20  21  22 100 100]]

实际上我矩阵的形状是6000 * 6000。如何以更好的方式做到这一点?喜欢申请吗?

1 个答案:

答案 0 :(得分:1)

您可以在此处使用argpartition

N=2
ix = a.argpartition(-N)[:,-N:]
a[np.arange(a.shape[0])[:,None], ix] = 100

print(a)
array([[  0,   1,   2, 100, 100],
       [  5,   6,   7, 100, 100],
       [ 10,  11,  12, 100, 100],
       [ 15,  16,  17, 100, 100],
       [ 20,  21,  22, 100, 100]])