假设我有一堆数组,包括x
和y
,我想检查它们是否相等。一般来说,我可以使用np.all(x == y)
(除非我现在忽略了一些愚蠢的角落案例)。
但是,这会评估(x == y)
的整个数组,这通常是不需要的。我的数组非常大,我有很多,两个数组相等的概率很小,所以很可能,我真的只需要在{(x == y)
之前评估all
的一小部分。 1}}函数可以返回False,所以这对我来说不是最佳解决方案。
我已尝试将内置all
功能与itertools.izip
结合使用:all(val1==val2 for val1,val2 in itertools.izip(x, y))
然而,在两个数组 相等的情况下,这似乎要慢得多,总的来说,它不值得在np.all
上使用。我认为是因为内置all
的通用性。 np.all
并不适用于生成器。
有没有办法以更快的方式做我想要的事情?
我知道这个问题类似于之前提出的问题(例如Comparing two numpy arrays for equality, element-wise),但他们特别不提及提前终止的问题。
答案 0 :(得分:7)
在本地实现numpy之前,您可以编写自己的函数并使用numba进行jit编译:
import numpy as np
import numba as nb
@nb.jit(nopython=True)
def arrays_equal(a, b):
if a.shape != b.shape:
return False
for ai, bi in zip(a.flat, b.flat):
if ai != bi:
return False
return True
a = np.random.rand(10, 20, 30)
b = np.random.rand(10, 20, 30)
%timeit np.all(a==b) # 100000 loops, best of 3: 9.82 µs per loop
%timeit arrays_equal(a, a) # 100000 loops, best of 3: 9.89 µs per loop
%timeit arrays_equal(a, b) # 100000 loops, best of 3: 691 ns per loop
最差案例性能(数组相等)相当于np.all
,如果提前停止,编译函数有可能大大超过np.all
。
答案 1 :(得分:1)
在numpy page on github上显然正在讨论为阵列比较添加短路逻辑,因此可能会在未来的numpy版本中提供。
答案 2 :(得分:0)
您可以迭代数组的所有元素并检查它们是否相等。 如果数组很可能不相等,则返回的速度比.all函数快得多。 像这样:
<script src="https://ajax.googleapis.com/ajax/libs/jquery/2.1.1/jquery.min.js"></script>
<input>
答案 3 :(得分:0)
理解基础数据结构的人可能会对此进行优化或解释它是否可靠/安全/良好实践,但它似乎有效。
np.all(a==b)
Out[]: True
memoryview(a.data)==memoryview(b.data)
Out[]: True
%timeit np.all(a==b)
The slowest run took 10.82 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 6.2 µs per loop
%timeit memoryview(a.data)==memoryview(b.data)
The slowest run took 8.55 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 1.85 µs per loop
如果我理解正确,ndarray.data
会创建一个指向数据缓冲区的指针,而memoryview
会创建一个可以从缓冲区中短路的本机python类型。
我想。
编辑:进一步的测试显示它可能没有显示出的时间改善那么大。以前a=b=np.eye(5)
a=np.random.randint(0,10,(100,100))
b=a.copy()
%timeit np.all(a==b)
The slowest run took 6.70 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 17.7 µs per loop
%timeit memoryview(a.data)==memoryview(b.data)
10000 loops, best of 3: 30.1 µs per loop
np.all(a==b)
Out[]: True
memoryview(a.data)==memoryview(b.data)
Out[]: True
答案 4 :(得分:0)
def compare(a, b):
if len(a) > 0 and not np.array_equal(a[0], b[0]):
return False
if len(a) > 15 and not np.array_equal(a[:15], b[:15]):
return False
if len(a) > 200 and not np.array_equal(a[:200], b[:200]):
return False
return np.array_equal(a, b)
:)
答案 5 :(得分:0)
嗯,不是真正的答案,因为我没有检查它是否断路,而是:
从文档中:
如果两个
array_like
对象不相等,则引发AssertionError。
Try
Except
(如果不在性能敏感的代码路径上)。
或者遵循底层的源代码,也许是有效的。
答案 6 :(得分:0)
正如ThomasKühn在对您的帖子的评论中所写的那样,array_equal
是一个可以解决问题的函数。在Numpy's API reference中有描述。