如何才能获得与测试数组的所有元素匹配的数组元素?例如,如果我有:
>>> import numpy as np
>>> arr = np.array([[0, 0, 1], [1, 0, 1], [1, 0, 1]])
>>> arr == [0,0,1]
array([[ True, True, True],
[False, True, True],
[False, True, True]], dtype=bool)
arr == [0,0,1]
的解决方案是索引0
答案 0 :(得分:2)
您需要使用axis
参数逐行检查条件:
>>> (arr == np.array([0,0,1])).all(axis=1)
array([ True, False, False], dtype=bool)
如果你想索引:
>>> np.where((arr == np.array([0,0,1])).all(axis=1))
(array([0]),)
我们还可以做一些整洁(和快速)的事情,以防止从np.reduce
调用np.all
:
>>> b = np.array([0,0,1])
>>> dt = np.dtype((np.void, arr.dtype.itemsize * arr.shape[1]))
>>> (arr.view(dt) == b.view(dt)).reshape(-1)
array([ True, False, False], dtype=bool)
一些时间:
arr = np.random.randint(0,2,(1E2,3))
%timeit (arr.view(dt) == b.view(dt)).reshape(-1)
100000 loops, best of 3: 7.76 µs per loop
%timeit (arr == b).all(axis=1)
100000 loops, best of 3: 13.5 µs per loop
使用更大的数组:
arr = np.random.randint(0,2,(1E5,3))
%timeit (arr.view(dt) == b.view(dt)).reshape(-1)
1 loops, best of 3: 221 ms per loop
%timeit (arr == b)).all(axis=1)
1 loops, best of 3: 315 ms per loop
答案 1 :(得分:0)
试试这个:
print arr[np.all((arr == [0,0,1]),axis=1)]
或者:
print np.arange(arr.shape[0])[np.all((arr == [0,0,1]),axis=1)]
如果您只想要答案的索引