在numpy数组中有效地查找numpy数组中的值

时间:2017-05-23 20:36:45

标签: python numpy

我正在尝试创建仅包含特定值的numpy数组的副本。这是我正在使用的代码:

A = np.array([[1,2,3],[4,5,6],[7,8,9]])
query_val = 5
B = (A == query_val) * np.array(query_val, dtype=np.uint16)

......这正是我想要的。

现在,我希望query_val不仅仅是一个值。这里的答案是:Numpy where function multiple conditions建议使用逻辑和操作,但由于你多次使用==,创建多个中间结果,因此空间效率很低。

在我的情况下,这意味着我没有足够的RAM来做到这一点。有没有办法在原生numpy中以最小的空间开销正确地做到这一点?

1 个答案:

答案 0 :(得分:0)

这是使用np.searchsorted -

的一种方法
def mask_in(a, b):
    idx = np.searchsorted(b,a)
    idx[idx==b.size] = 0
    return np.where(b[idx]==a, a,0)

示例运行 -

In [356]: a
Out[356]: 
array([[5, 1, 4],
       [4, 5, 6],
       [2, 4, 9]])

In [357]: b
Out[357]: array([2, 4, 5])

In [358]: mask_in(a,b)
Out[358]: 
array([[5, 0, 4],
       [4, 5, 0],
       [2, 4, 0]])