是否有一种Pythonic且有效的方法来检查Numpy数组是否至少包含给定行的一个实例?通过"高效"我的意思是它终止于找到第一个匹配的行而不是迭代整个数组,即使已经找到了结果。
使用Python数组,这可以用if row in array:
非常干净地完成,但是这不像我对Numpy数组所期望的那样工作,如下所示。
使用Python数组:
>>> a = [[1,2],[10,20],[100,200]]
>>> [1,2] in a
True
>>> [1,20] in a
False
但是Numpy数组给出了不同的,而且看起来很奇怪的结果。 (__contains__
的{{1}}方法似乎没有记录。)
ndarray
答案 0 :(得分:36)
您可以使用.tolist()
>>> a = np.array([[1,2],[10,20],[100,200]])
>>> [1,2] in a.tolist()
True
>>> [1,20] in a.tolist()
False
>>> [1,20] in a.tolist()
False
>>> [1,42] in a.tolist()
False
>>> [42,1] in a.tolist()
False
或使用视图:
>>> any((a[:]==[1,2]).all(1))
True
>>> any((a[:]==[1,20]).all(1))
False
或者生成numpy列表(可能非常慢):
any(([1,2] == x).all() for x in a) # stops on first occurrence
或使用numpy逻辑函数:
any(np.equal(a,[1,2]).all(1))
如果你计算时间:
import numpy as np
import time
n=300000
a=np.arange(n*3).reshape(n,3)
b=a.tolist()
t1,t2,t3=a[n//100][0],a[n//2][0],a[-10][0]
tests=[ ('early hit',[t1, t1+1, t1+2]),
('middle hit',[t2,t2+1,t2+2]),
('late hit', [t3,t3+1,t3+2]),
('miss',[0,2,0])]
fmt='\t{:20}{:.5f} seconds and is {}'
for test, tgt in tests:
print('\n{}: {} in {:,} elements:'.format(test,tgt,n))
name='view'
t1=time.time()
result=(a[...]==tgt).all(1).any()
t2=time.time()
print(fmt.format(name,t2-t1,result))
name='python list'
t1=time.time()
result = True if tgt in b else False
t2=time.time()
print(fmt.format(name,t2-t1,result))
name='gen over numpy'
t1=time.time()
result=any((tgt == x).all() for x in a)
t2=time.time()
print(fmt.format(name,t2-t1,result))
name='logic equal'
t1=time.time()
np.equal(a,tgt).all(1).any()
t2=time.time()
print(fmt.format(name,t2-t1,result))
你可以看到命中或未命中,numpy例程与搜索数组的速度相同。对于早期命中,Python in
运算符可能的速度要快得多,如果必须一直通过数组,则生成器只是坏消息。
以下是300,000 x 3元素数组的结果:
early hit: [9000, 9001, 9002] in 300,000 elements:
view 0.01002 seconds and is True
python list 0.00305 seconds and is True
gen over numpy 0.06470 seconds and is True
logic equal 0.00909 seconds and is True
middle hit: [450000, 450001, 450002] in 300,000 elements:
view 0.00915 seconds and is True
python list 0.15458 seconds and is True
gen over numpy 3.24386 seconds and is True
logic equal 0.00937 seconds and is True
late hit: [899970, 899971, 899972] in 300,000 elements:
view 0.00936 seconds and is True
python list 0.30604 seconds and is True
gen over numpy 6.47660 seconds and is True
logic equal 0.00965 seconds and is True
miss: [0, 2, 0] in 300,000 elements:
view 0.00936 seconds and is False
python list 0.01287 seconds and is False
gen over numpy 6.49190 seconds and is False
logic equal 0.00965 seconds and is False
对于3,000,000 x 3阵列:
early hit: [90000, 90001, 90002] in 3,000,000 elements:
view 0.10128 seconds and is True
python list 0.02982 seconds and is True
gen over numpy 0.66057 seconds and is True
logic equal 0.09128 seconds and is True
middle hit: [4500000, 4500001, 4500002] in 3,000,000 elements:
view 0.09331 seconds and is True
python list 1.48180 seconds and is True
gen over numpy 32.69874 seconds and is True
logic equal 0.09438 seconds and is True
late hit: [8999970, 8999971, 8999972] in 3,000,000 elements:
view 0.09868 seconds and is True
python list 3.01236 seconds and is True
gen over numpy 65.15087 seconds and is True
logic equal 0.09591 seconds and is True
miss: [0, 2, 0] in 3,000,000 elements:
view 0.09588 seconds and is False
python list 0.12904 seconds and is False
gen over numpy 64.46789 seconds and is False
logic equal 0.09671 seconds and is False
这似乎表明np.equal
是执行此操作的最快的纯粹方式......
答案 1 :(得分:18)
Numpys __contains__
is, at the time of writing this, (a == b).any()
如果b
是一个标量(它有点毛茸茸,但我相信 - 只在1.7或更高版本中这样做 - 这可能是唯一正确的 - 这个将是正确的通用方法(a == b).all(np.arange(a.ndim - b.ndim, a.ndim)).any()
,这对于a
和b
维度的所有组合都有意义... ...
编辑:为了清楚起见,当涉及广播时,这 不一定是预期的结果。也有人可能会争辩说它应该a
分别处理np.in1d
中的项目。我不确定它应该有一个明确的方法。
现在你想要numpy在找到第一个匹配项时停止。此AFAIK目前不存在。这很难,因为numpy主要基于ufuncs,它在整个数组中做同样的事情。
Numpy确实优化了这种减少,但实际上只有在减少的数组已经是一个布尔数组(即np.ones(10, dtype=bool).any()
)时才有效。
否则它将需要__contains__
的特殊功能,该功能不存在。这可能看起来很奇怪,但你必须记住numpy支持许多数据类型,并且有更大的机制来选择正确的数据并选择正确的函数来处理它。换句话说,ufunc机制无法做到这一点,并且由于数据类型的原因,实现__contains__
或其他特殊实际上并不是那么微不足道。
你当然可以在python中编写它,或者因为你可能知道你的数据类型,所以在Cython / C中自己编写它非常简单。
那就是说。对于这些事情,使用基于排序的方法通常会好得多。这有点单调乏味,searchsorted
没有lexsort
这样的事情,但它有效(如果你愿意,你也可以滥用scipy.spatial.cKDTree
)。这假设您只想沿最后一个轴进行比较:
# Unfortunatly you need to use structured arrays:
sorted = np.ascontiguousarray(a).view([('', a.dtype)] * a.shape[-1]).ravel()
# Actually at this point, you can also use np.in1d, if you already have many b
# then that is even better.
sorted.sort()
b_comp = np.ascontiguousarray(b).view(sorted.dtype)
ind = sorted.searchsorted(b_comp)
result = sorted[ind] == b_comp
这也适用于数组b
,如果你保持排序的数组,如果你同时在b
中为单个值(行)执行它也会好得多,当时a
保持不变(否则我只会np.in1d
将其视为重新排列)。 重要提示:为安全起见,您必须执行np.ascontiguousarray
。它通常什么都不做,但如果确实如此,那将是一个很大的潜在错误。
答案 2 :(得分:8)
我认为
equal([1,2], a).all(axis=1) # also, ([1,2]==a).all(axis=1)
# array([ True, False, False], dtype=bool)
将列出匹配的行。正如Jamie所指出的,要知道是否存在至少一个这样的行,请使用any
:
equal([1,2], a).all(axis=1).any()
# True
除此之外:
我怀疑in
(和__contains__
)与上面一样,但使用的是any
而不是all
。
答案 3 :(得分:1)
如果你真的想在第一次出现时停下来,你可以写一个循环,比如:
import numpy as np
needle = np.array([10, 20])
haystack = np.array([[1,2],[10,20],[100,200]])
found = False
for row in haystack:
if np.all(row == needle):
found = True
break
print("Found: ", found)
但是,我强烈怀疑,它会比使用numpy例程为整个数组执行它的其他建议慢得多。
答案 4 :(得分:1)
我将建议的解决方案与 perfplot 进行了比较,发现如果您要在长长的未排序列表中寻找 2 元组,
np.any(np.all(a == b, axis=1))
是最快的解决方案。如果在前几行中找到匹配项,则显式短路循环总是更快。
重现情节的代码:
import numpy as np
import perfplot
target = [6, 23]
def setup(n):
return np.random.randint(0, 100, (n, 2))
def any_all(data):
return np.any(np.all(target == data, axis=1))
def tolist(data):
return target in data.tolist()
def loop(data):
for row in data:
if np.all(row == target):
return True
return False
def searchsorted(a):
s = np.ascontiguousarray(a).view([('', a.dtype)] * a.shape[-1]).ravel()
s.sort()
t = np.ascontiguousarray(target).view(s.dtype)
ind = s.searchsorted(t)
return (s[ind] == t)[0]
perfplot.save(
"out02.png",
setup=setup,
kernels=[any_all, tolist, loop, searchsorted],
n_range=[2 ** k for k in range(2, 20)],
xlabel="len(array)",
)