替换numpy数组的所有值,这些值小于每行中“n”个最大项

时间:2017-08-22 11:12:50

标签: python numpy multidimensional-array

我有2d numpy数组大小~70k * 10k。我想用零替换所有小于每行中“N”个最大元素的值。例如:

arr = np.array([[1, 0, 6, 5, 2, 5], 
                [7, 5, 2, 6, 7, 3], 
                [3, 5, 1, 5, 6, 4]])

对于N = 3,结果应为:

result = np.array([[0, 0, 6, 5, 0, 5], # 3 largest in row: 6, 5, 5
                   [7, 0, 0, 6, 7, 0], 
                   [0, 5, 0, 5, 6, 0]])

未替换的数字位置和数组的形状应保持不变。

1 个答案:

答案 0 :(得分:4)

您可以使用np.partition找到N - 最大值,然后使用布尔索引来替换" s"下面的所有内容"它的行中的值:

import numpy as np
arr = np.array([[1, 0, 6, 5, 2, 5], 
                [7, 5, 2, 6, 7, 3], 
                [3, 5, 1, 5, 6, 4]])

N = 3
nlargest = np.partition(arr, -N, axis=1)[:, -N]
arr[arr < nlargest[:, None]] = 0
arr
# array([[0, 0, 6, 5, 0, 5],
#        [7, 0, 0, 6, 7, 0],
#        [0, 5, 0, 5, 6, 0]])