高效的numpy argsort与条件同时保持原始指数

时间:2018-01-25 00:59:24

标签: python numpy

我想知道在给定条件的情况下做一个数组的argsort最有效的方法,同时保留原始索引

import numpy as np

x = np.array([0.63, 0.5, 0.7, 0.65])

np.argsort(x)
#Corrected argsort(x) solution
Out[99]: array([1, 0, 3, 2])

我希望使用x> 0.6的条件来控制此数组。因为0.5 < 0.6,不应包括指数1.

x = np.array([0.63, 0.5, 0.7, 0.65])
index = x.argsort()
list(filter(lambda i: x[i] > 0.6, index))

[0,3,2]

这是低效的,因为它没有矢量化。

编辑: 过滤器将消除大多数元素。理想情况下,它首先进行过滤,然后进行排序,同时保留原始索引。

4 个答案:

答案 0 :(得分:7)

方法1(与Tai的方法相同,但使用整数索引)

派对太迟了,如果我的解决方案是重复已经发布的解决方案 - ping我,我会删除它。

def meth_agn_v1(x, thresh):
    idx = np.arange(x.size)[x > thresh]
    return idx[np.argsort(x[idx])]

然后,

In [143]: meth_agn_v1(x, 0.5)
Out[143]: array([0, 3, 2])

方法2(显着的性能改进)

这使用了我的答案的最后一部分(与Tai的方法比较)中表达的相同的想法,即整数索引比布尔索引更快(对于要选择的少量预期元素)并且完全避免创建初始索引。

def meth_agn_v2(x, thresh):
    idx, = np.where(x > thresh)
    return idx[np.argsort(x[idx])]

时序

In [144]: x = np.random.rand(100000)

In [145]: timeit meth_jp(x, 0.99)
100 loops, best of 3: 7.43 ms per loop

In [146]: timeit meth_alex(x, 0.99)
1000 loops, best of 3: 498 µs per loop

In [147]: timeit meth_tai(x, 0.99)
1000 loops, best of 3: 298 µs per loop

In [148]: timeit meth_agn_v1(x, 0.99)
1000 loops, best of 3: 232 µs per loop

In [161]: timeit meth_agn_v2(x, 0.99)
10000 loops, best of 3: 95 µs per loop

v1与Tai方法的比较

我答案的第一个版本与Tai的答案非常相似,但不完全相同。

最近出版的Tai的方法:

def meth_tai(x, thresh):
    y = np.arange(x.shape[0])
    y = y [x > thresh]  
    x = x [x > thresh] # x = x[y] is used in my method
    y[np.argsort(x)]

所以,我的方法与使用整数数组索引而不是Tai使用的布尔索引不同。对于少量选定的元素,整数索引比布尔索引更快,这使得这种方法比Tai的方法更有效,即使在Tai优化了他的代码之后。

答案 1 :(得分:6)

参加聚会有点晚了。我们的想法是,我们可以根据另一个数组的排序索引对数组进行排序。

y = np.arange(x.shape[0]) # y for preserving the indices
mask = x > thresh
y = y[mask]  
x = x[mask]
ans = y[np.argsort(x)]    # change order of y based on sorted indices of x

该方法是添加一个仅用于记录y索引的数组x。然后,我们根据布尔索引x > thresh过滤掉两个数组。然后,使用xargsort进行排序。最后,使用argsort返回的索引来更改y

的顺序

答案 2 :(得分:3)

方法1(@jp_data_analysis答案)

除非你有理由不这样做,否则你应该使用这个。

def meth1(x, thresh):
    return np.argsort(x)[(x <= thresh).sum():]

方法2

如果过滤器会大大减少数组中元素的数量并且数组很大,那么以下内容可能有所帮助:

def meth2(x, thresh):
    m = x > thresh
    idxs = np.argsort(x[m])
    offsets = (~m).cumsum()
    return idxs + offsets[m][idxs]

速度比较

x = np.random.rand(10000000)

%timeit meth1(x, 0.99)
# 2.81 s ± 244 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit meth2(x, 0.99)
# 104 ms ± 1.22 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

答案 3 :(得分:0)

这是另一种破解方法,它使用一些任意最大数量修改原始数组,这在原始数组中不太可能发生。

In [50]: x = np.array([0.63, 0.5, 0.7, 0.65])
In [51]: invmask = ~(x > 0.6)

# replace it with some huge number which will not occur in your original array
In [52]: x[invmask] = 9999.0

In [53]: np.argsort(x)[:-sum(invmask)]
Out[53]: array([0, 3, 2])