如何使用python中的条件过滤numpy数组

时间:2018-02-11 23:07:41

标签: python numpy

我正在使用我的numpy数组v,如下所示删除< = 1的元素,然后选择numpy数组中前3个元素的索引。

 for ele in v.toarray()[0].tolist():
        if ele <= 1:
            useless_index = v.toarray()[0].tolist().index(ele)
            temp_list.append(useless_index)

 #take top 3 words from each document
 indexes =v.toarray()[0].argsort()[-3:]
 useful_list = list(set(indexes) - set(temp_list))

但是,我使用的当前代码非常慢(因为我有数百万个numpy数组)并且需要数天才能运行。有没有有效的方法在python中做同样的事情?

1 个答案:

答案 0 :(得分:3)

v = v[v > 1]
indices = np.argpartition(v, -3)[-3:]
values = v[indices]

如上所述hereargpartitionO(n + k log k)时间内运行。在您的情况下,n = 1e6k=3