只获取与测试数组的所有元素匹配的数组元素?

时间:2014-02-17 20:18:05

标签: python arrays numpy

如何才能获得与测试数组的所有元素匹配的数组元素?例如,如果我有:

>>> import numpy as np

>>> arr = np.array([[0, 0, 1], [1, 0, 1], [1, 0, 1]])
>>> arr == [0,0,1]
array([[ True,  True,  True],
   [False,  True,  True],
   [False,  True,  True]], dtype=bool)

arr == [0,0,1]的解决方案是索引0

2 个答案:

答案 0 :(得分:2)

您需要使用axis参数逐行检查条件:

>>> (arr == np.array([0,0,1])).all(axis=1)
array([ True, False, False], dtype=bool)

如果你想索引:

>>> np.where((arr == np.array([0,0,1])).all(axis=1))
(array([0]),)

我们还可以做一些整洁(和快速)的事情,以防止从np.reduce调用np.all

>>> b = np.array([0,0,1])
>>> dt = np.dtype((np.void, arr.dtype.itemsize * arr.shape[1]))
>>> (arr.view(dt) == b.view(dt)).reshape(-1)
array([ True, False, False], dtype=bool)

一些时间:

arr = np.random.randint(0,2,(1E2,3))

%timeit (arr.view(dt) == b.view(dt)).reshape(-1)
100000 loops, best of 3: 7.76 µs per loop
%timeit (arr == b).all(axis=1)
100000 loops, best of 3: 13.5 µs per loop

使用更大的数组:

arr = np.random.randint(0,2,(1E5,3))

%timeit (arr.view(dt) == b.view(dt)).reshape(-1)
1 loops, best of 3: 221 ms per loop    
%timeit (arr == b)).all(axis=1)
1 loops, best of 3: 315 ms per loop

答案 1 :(得分:0)

试试这个:

print arr[np.all((arr == [0,0,1]),axis=1)]

或者:

print np.arange(arr.shape[0])[np.all((arr == [0,0,1]),axis=1)]

如果您只想要答案的索引