我有一个numpy数组列表,想检查所有数组是否相等。这样做最快的方法是什么?
我知道numpy.array_equal函数(https://docs.scipy.org/doc/numpy-1.10.0/reference/generated/numpy.array_equal.html),但据我所知,这只适用于两个数组,我想检查N个数组。
我还找到了这个答案来测试列表中的所有元素:check if all elements in a list are identical。 但是,当我在接受的答案中尝试每个方法时,我得到一个异常(ValueError:具有多个元素的数组的真值是不明确的。使用a.any()或a.all())
谢谢,
答案 0 :(得分:3)
您可以简单地adapt a general iterator method进行数组比较
def all_equal(iterator):
try:
iterator = iter(iterator)
first = next(iterator)
return all(np.array_equal(first, rest) for rest in iterator)
except StopIteration:
return True
如果这不起作用,则表示您的阵列不相等。
演示:
>>> i = [np.array([1,2,3]),np.array([1,2,3]),np.array([1,2,3])]
>>> print(all_equal(i))
True
>>> j = [np.array([1,2,4]),np.array([1,2,3]),np.array([1,2,3])]
>>> print(all_equal(j))
False
答案 1 :(得分:0)
我猜你可以使用这个独特的功能。
http://docs.scipy.org/doc/numpy-1.10.1/reference/generated/numpy.unique.html#numpy.unique
如果数组中的所有子数组都相同,则只返回一个项目。
这里有更好的描述如何使用它。
答案 2 :(得分:0)
如果您的数组大小相同,那么使用numpy_indexed(免责声明:我是其作者)的此解决方案应该可以正常工作并且非常高效:
import numpy_indexed as npi
npi.all_unique(list_of_arrays)