我的代码:
import numpy as np
N = 2
a = np.array([[0.5, 0.3, 0.2],
[0.2, 0.6, 0.2],
[0.3, 0.2, 0.7],
[np.nan, 0.2, 0.8],
[np.nan, np.nan, 0.8]
])
ind = np.argsort(np.where(np.isnan(a), -1, a), axis=1)[:, -N:]
a
Out[2]:
array([[ 0.5, 0.3, 0.2],
[ 0.2, 0.6, 0.2],
[ 0.3, 0.2, 0.7],
[ nan, 0.2, 0.8],
[ nan, nan, 0.8]])
ind
Out[3]:
array([[1, 0],
[2, 1],
[0, 2],
[1, 2],
[1, 2]], dtype=int64)
ind [:,1]是最高的,而ind [:,0]是第二高
除了最后一行中有2个nans的情况外,这很好。 如果是纳米,如何忽略第二高的值? 期望的输出将是:
array([[1, 0],
[2, 1],
[0, 2],
[1, 2],
[nan, 2]], dtype=int64)
奖金问题:在[1,:]的情况下如何随机打破平局?
答案 0 :(得分:1)
Advanced-index
并检查NaNs
是否为我们提供了一个掩码,然后可以与np.where
一起使用进行选择,就像这样 -
In [244]: a_ind = a[np.arange(ind.shape[0])[:,None],ind]
In [245]: mask = np.isnan(a_ind)
In [246]: np.where(mask, np.nan, ind)
Out[246]:
array([[ 1., 0.],
[ 2., 1.],
[ 0., 2.],
[ 1., 2.],
[ nan, 2.]])
请注意,具有NaN
的数组将转换为float
dtype,因此最终输出也将为float
dtype。