Starting from the following array:
array([ nan, nan, nan, 1., nan, nan, 0., nan, nan])
which is generated like this:
import numpy as np
row = np.array([ np.nan, np.nan, np.nan, 1., np.nan, np.nan, 0., np.nan, np.nan])
I want to get the indices that would sort the array, with the NaNs excluded. In this case, I want to get [6, 3].
I came up with the following way to do it:
vals = np.sort(row)
inds = np.argsort(row)
def select_index_by_value(indices, values):
    selected_indices = []
    for i in range(len(indices)):
        if not np.isnan(values[i]):
            selected_indices.append(indices[i])
    return selected_indices
selected_inds = select_index_by_value(inds, vals)
Now selected_inds is [6, 3]. However, this seems like a lot of code for such a simple operation. Is there a shorter way to do it?
Answer 0: (score: 3)
You can do it like this:
# Store non-NaN indices
idx = np.where(~np.isnan(row))[0]
# Select non-NaN elements, perform argsort and use those argsort
# indices to re-order non-NaN indices as final output
out = idx[row[idx].argsort()]
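As a quick check of the approach above, here is a minimal sketch that wraps the two steps in a function and runs it on the array from the question. The name `argsort_without_nan` is made up for illustration and does not come from the answer.

```python
import numpy as np

def argsort_without_nan(a):
    # Hypothetical helper name, used only for this illustration.
    idx = np.where(~np.isnan(a))[0]   # positions of the non-NaN entries
    return idx[a[idx].argsort()]      # reorder those positions by value

row = np.array([np.nan, np.nan, np.nan, 1., np.nan, np.nan, 0., np.nan, np.nan])
result = argsort_without_nan(row)
print(result)  # [6 3]
```

Since only the non-NaN values are sorted, the cost of the sort depends on how many valid entries there are rather than on the full length of the array.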
Answer 1: (score: 1)
Another option:
row.argsort()[~np.isnan(np.sort(row))]
# array([6, 3])
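This one-liner appears to rely on NumPy placing NaNs at the end in both `np.sort` and `argsort`, so the mask built from the sorted values lines up with the argsort result. The sketch below spells that out and, under the same assumption, shows an equivalent slicing variant. The slicing form is my own illustration, not part of the answer.

```python
import numpy as np

row = np.array([np.nan, np.nan, np.nan, 1., np.nan, np.nan, 0., np.nan, np.nan])

order = row.argsort()               # NaNs are placed at the end
mask = ~np.isnan(np.sort(row))      # True for the leading non-NaN slots
masked = order[mask]

# Illustrative alternative (not from the answer): keep the first k entries,
# where k is the number of non-NaN values.
k = np.count_nonzero(~np.isnan(row))
sliced = order[:k]

print(masked, sliced)  # [6 3] [6 3]
```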
Answer 2: (score: 0)
Here is another, faster solution (for the OP's data).
Psidom's solution
%timeit row.argsort()[~np.isnan(np.sort(row))]
The slowest run took 31.23 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 8.16 µs per loop
%timeit idx = np.where(~np.isnan(row))[0]; idx[row[idx].argsort()]
The slowest run took 35.11 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 4.73 µs per loop
Based on Divakar's solution
%timeit np.where(~np.isnan(row))[0][::-1]
The slowest run took 9.42 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 2.86 µs per loop
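I think this works because np.where(~np.isnan(row)) preserves the original order of the indices. Note that reversing those indices gives the sorted order only when the non-NaN values are already in descending order, as they are in the OP's data; it does not replace argsort in general.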
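To see when the reversal shortcut agrees with a real argsort, here is a small sketch comparing the two on the OP's array and on a made-up array whose non-NaN values are not in descending order. The function names and the `other` array are assumptions for illustration only.

```python
import numpy as np

def reversed_non_nan(a):
    # The shortcut timed above: non-NaN positions, simply reversed.
    return np.where(~np.isnan(a))[0][::-1]

def argsort_non_nan(a):
    # Sort-based version: order the non-NaN positions by their values.
    idx = np.where(~np.isnan(a))[0]
    return idx[a[idx].argsort()]

row = np.array([np.nan, np.nan, np.nan, 1., np.nan, np.nan, 0., np.nan, np.nan])
other = np.array([np.nan, 0., np.nan, 2., 1.])  # made-up data

print(reversed_non_nan(row), argsort_non_nan(row))      # [6 3] [6 3]
print(reversed_non_nan(other), argsort_non_nan(other))  # [4 3 1] [1 4 3]
```

On the second array the two results differ, so the shortcut is best treated as specific to the OP's data.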