假设我有一个2D numpy数组。给定n,我希望对矩阵中除前n个元素之外的所有元素求n。
我尝试过idx = (-y_pred).argsort(axis=-1)[:, :n]
来确定最大n值的索引是什么,但是idx
的形状是[H,W,n],我不明白为什么。
我已经尝试过-
sorted_list = sorted(y_pred, key=lambda x: x[0], reverse=True)
top_ten = sorted_list[:10]
但是它并没有真正返回前10个索引。
有没有一种有效的方法来找到前n个索引,其余的归零?
编辑 输入是值的NxM矩阵,输出是大小NxM的相同矩阵,因此除对应于前10个值的索引外,所有值均为0
答案 0 :(得分:1)
这是基于numpy.argpartition()
的想法使用How do I get indices of N maximum values in a NumPy array?的一种方法
# sample input to work with
In [62]: arr = np.random.randint(0, 30, 36).reshape(6, 6)
In [63]: arr
Out[63]:
array([[ 8, 25, 12, 26, 21, 29],
[24, 22, 7, 14, 23, 13],
[ 1, 22, 18, 20, 10, 19],
[26, 10, 27, 19, 6, 28],
[17, 28, 9, 13, 11, 12],
[18, 25, 15, 29, 25, 25]])
# initialize an array filled with zeros
In [59]: nullified_arr = np.zeros_like(arr)
In [64]: top_n = 10
# get top_n indices of `arr`
In [57]: top_n_idxs = np.argpartition(arr.reshape(-1), -top_n)[-top_n:]
# copy `top_n` values to output array
In [60]: nullified_arr.reshape(-1)[top_n_idxs] = arr.reshape(-1)[top_n_idxs]
In [71]: nullified_arr
Out[71]:
array([[ 0, 25, 0, 26, 0, 29],
[ 0, 0, 0, 0, 0, 0],
[ 0, 0, 0, 0, 0, 0],
[26, 0, 27, 0, 0, 28],
[ 0, 28, 0, 0, 0, 0],
[ 0, 0, 0, 29, 25, 25]])
答案 1 :(得分:1)
以下代码将使NxM
矩阵X
无效。
threshold = np.sort(X.ravel())[-n] # get the nth largest value
idx = X < threshold
X[idx] = 0
注意:当值重复时,此方法可以返回一个包含多于n
个非零元素的矩阵。