如何找到与某个列表匹配的numpy二维数组中的所有元素?

时间:2016-01-22 11:09:17

标签: python arrays performance numpy vectorization

我有一个二维NumPy数组,例如:

array([[1, 1, 0, 2, 2],
       [1, 1, 0, 2, 0],
       [0, 0, 0, 0, 0],
       [3, 3, 0, 4, 4],
       [3, 3, 0, 4, 4]])

我想从该数组中获取某些列表中的所有元素,例如(1,3,4)。示例中的期望结果是:

array([[1, 1, 0, 0, 0],
       [1, 1, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [3, 3, 0, 4, 4],
       [3, 3, 0, 4, 4]])

我知道我可以这样做(按照Numpy: find elements within range的建议):

np.logical_or(
    np.logical_or(cc_labeled == 1, cc_labeled == 3),
    cc_labeled == 4
)

,但这只会在示例中合理有效。实际上迭代地使用for循环和numpy.logical_or结果非常慢,因为可能的值列表是数千(而numpy数组的大小大约是1000 x 1000)。

2 个答案:

答案 0 :(得分:3)

您可以使用np.in1d -

A*np.in1d(A,[1,3,4]).reshape(A.shape)

此外,np.where可以使用 -

np.where(np.in1d(A,[1,3,4]).reshape(A.shape),A,0)

您还可以使用np.searchsorted来查找此类匹配,方法是使用其可选的'side'参数,输入为leftright,并注意匹配时,searchsorted将用这两个输入输出不同的结果。因此,相当于np.in1d(A,[1,3,4])将是 -

M = np.searchsorted([1,3,4],A.ravel(),'left') != \
    np.searchsorted([1,3,4],A.ravel(),'right')

因此,最终输出将是 -

out = A*M.reshape(A.shape)

请注意,如果未对输入搜索列表进行排序,则需要在sorter中使用带有argsort索引的可选参数np.searchsorted

示例运行 -

In [321]: A
Out[321]: 
array([[1, 1, 0, 2, 2],
       [1, 1, 0, 2, 0],
       [0, 0, 0, 0, 0],
       [3, 3, 0, 4, 4],
       [3, 3, 0, 4, 4]])

In [322]: A*np.in1d(A,[1,3,4]).reshape(A.shape)
Out[322]: 
array([[1, 1, 0, 0, 0],
       [1, 1, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [3, 3, 0, 4, 4],
       [3, 3, 0, 4, 4]])

In [323]: np.where(np.in1d(A,[1,3,4]).reshape(A.shape),A,0)
Out[323]: 
array([[1, 1, 0, 0, 0],
       [1, 1, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [3, 3, 0, 4, 4],
       [3, 3, 0, 4, 4]])

In [324]: M = np.searchsorted([1,3,4],A.ravel(),'left') != \
     ...:     np.searchsorted([1,3,4],A.ravel(),'right')
     ...: A*M.reshape(A.shape)
     ...: 
Out[324]: 
array([[1, 1, 0, 0, 0],
       [1, 1, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [3, 3, 0, 4, 4],
       [3, 3, 0, 4, 4]])

运行时测试并验证输出 -

In [309]: # Inputs
     ...: A = np.random.randint(0,1000,(400,500))
     ...: lst = np.sort(np.random.randint(0,1000,(100))).tolist()
     ...: 
     ...: def func1(A,lst):                         
     ...:   return A*np.in1d(A,lst).reshape(A.shape)
     ...: 
     ...: def func2(A,lst):                         
     ...:   return np.where(np.in1d(A,lst).reshape(A.shape),A,0)
     ...: 
     ...: def func3(A,lst):                         
     ...:   mask = np.searchsorted(lst,A.ravel(),'left') != \
     ...:          np.searchsorted(lst,A.ravel(),'right')
     ...:   return A*mask.reshape(A.shape)
     ...: 

In [310]: np.allclose(func1(A,lst),func2(A,lst))
Out[310]: True

In [311]: np.allclose(func1(A,lst),func3(A,lst))
Out[311]: True

In [312]: %timeit func1(A,lst)
10 loops, best of 3: 30.9 ms per loop

In [313]: %timeit func2(A,lst)
10 loops, best of 3: 30.9 ms per loop

In [314]: %timeit func3(A,lst)
10 loops, best of 3: 28.6 ms per loop

答案 1 :(得分:3)

使用np.in1d

np.in1d(arr, [1,3,4]).reshape(arr.shape)

in1d,顾名思义,在扁平阵列上运行,因此您需要在操作后重新塑造。