使用numpy.argpartition忽略NaN

时间:2019-05-30 12:05:17

标签: python python-3.x numpy sorting numpy-ndarray

我有一个大约有4900万个项目(7000 * 7000)的大型阵列,在这里我需要找到最大的N个项目及其索引,而忽略所有NaN。我无法事先删除这些NaN,因为我需要从第一个数组中提取最大N个项的索引值,以从另一个数组中提取数据,这些NaN的索引与第一个数组相比有所不同。我尝试过

np.argpartition(first_array, -N)[-N:]

这对于没有NaN的数组非常有用,但是如果存在NaN,则nan将成为最大项,因为它在python中被视为无穷大。

x = np.array([np.nan, 2, -1, 2, -4, -8, -9, 6, -3]).reshape(3, 3)
y = np.argpartition(x.ravel() , -3)[-3:]
z = x.ravel()[y]
# this is the result I am getting  === [2, 6, nan]
# but I need this ==== [2, 2, 6]

1 个答案:

答案 0 :(得分:0)

使用NaN的数量来抵消,从而计算索引并提取值-

In [200]: N = 3

In [201]: c = np.isnan(x).sum()

In [204]: idx = np.argpartition(x.ravel() , -N-c)[-N-c:-c]

In [207]: val = x.flat[idx]

In [208]: idx,val
Out[208]: (array([1, 3, 7]), array([2., 2., 6.]))