numpy中的另一个数组过滤数组元素

时间:2018-12-26 06:14:22

标签: python arrays numpy

这是一个简单的示例

import numpy as np
x=np.random.rand(5,5)
k,p = np.where(x>0.5)

k和p是索引数组

现在我有一个应视为m = [0,2,4]的行的列表,因此我需要找到列表m中的k的所有条目。

我想出了一个非常简单但可怕的低效解决方案

d = np.array([ (a,b) for a,b in zip(k,p) if a in m])

该解决方案有效,但是非常慢。我正在寻找更好,更高效的产品。我需要使用动态调整的m进行数百万次此类操作,因此算法的效率确实是一个关键问题。

3 个答案:

答案 0 :(得分:2)

也许以下速度更快:

d=np.dstack((k,p))[0]
print(d[np.isin(d[:,0],m)])

答案 1 :(得分:1)

您可以使用isin()来获取布尔掩码,该布尔掩码可以用于索引k

>>> x=np.random.rand(3,3)
>>> x
array([[0.74043564, 0.48328081, 0.82396324],
       [0.40693944, 0.24951958, 0.18043229],
       [0.46623863, 0.53559775, 0.98956277]])
>>> k, p = np.where(x > 0.5)
>>> p
array([0, 2, 1, 2])
>>> k
array([0, 0, 2, 2])
>>> m
array([0, 1])  
>>> np.isin(k, m)
array([ True,  True, False, False])
>>> k[np.isin(k, m)]
array([0, 0])

答案 2 :(得分:0)

怎么样:

import numpy as np
m = np.array([0, 2, 4])
k, p = np.where(x[m, :] > 0.5)
k = m[k]
print(zip(k, p))

这仅考虑有趣的行(然后将它们压缩为2d索引)。