我想比较考虑位置的两个numpy数组一元素。例如
[1, 2, 3]==[1, 2, 3] -> True
[1, 2, 3]==[2, 1, 3] -> False
我尝试了以下
for index in range(list1.shape[0]):
if list1[index] != list2[index]:
return False
return True
但我收到以下错误
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
但是,以下内容并非正确使用.any或.all
numpy.any(numpy.array([1,2,3]), numpy.array([1,2,3]))
numpy.all(numpy.array([1,2,3]), numpy.array([1,2,3]))
返回时
TypeError: only length-1 arrays can be converted to Python scalars
我很困惑,有人可以解释我做错了吗
由于
答案 0 :(得分:3)
您还可以使用array_equal
:
In [11]: a = np.array([1, 2, 3])
In [12]: b = np.array([2, 1, 3])
In [13]: np.array_equal(a, a)
Out[13]: True
In [14]: np.array_equal(a, b)
Out[14]: False
这应该更高效,因为您不需要保留临时a==b
...
注意:对于要使用np.all
而不是all
的较大数组,请稍微考虑一下性能。 array_equal
执行大约相同的,除非数组提前不同,然后它会更快,因为它可以提前失败:
In [21]: a = np.arange(100000)
In [22]: b = np.arange(100000)
In [23]: c = np.arange(1, 100000)
In [24]: %timeit np.array_equal(a, a) # Note: I expected this to check is first, it doesn't
10000 loops, best of 3: 183 µs per loop
In [25]: %timeit np.array_equal(a, b)
10000 loops, best of 3: 189 µs per loop
In [26]: %timeit np.array_equal(a, c)
100000 loops, best of 3: 5.9 µs per loop
In [27]: %timeit np.all(a == b)
10000 loops, best of 3: 184 µs per loop
In [28]: %timeit np.all(a == c)
10000 loops, best of 3: 40.7 µs per loop
In [29]: %timeit all(a == b)
100 loops, best of 3: 3.69 ms per loop
In [30]: %timeit all(a == c) # ahem!
# TypeError: 'bool' object is not iterable
答案 1 :(得分:2)
您可以将一组布尔值传递给all
,例如:
>>> import numpy as np
>>> a = np.array([1, 2, 3])
>>> b = np.array([2, 1, 3])
>>> a == b
array([False, False, True], dtype=bool)
>>> np.all(a==b) # also works with all for 1D arrays
False
请注意,对于小型数组,内置all
比np.all
快得多(np.array_equal
仍然较慢):
>>> timeit.timeit("all(a==b)", setup="import numpy as np; a = np.array([1, 2, 3]); b = np.array([2, 1, 3])")
0.8798369040014222
>>> timeit.timeit("np.all(a==b)", setup="import numpy as np; a = np.array([1, 2, 3]); b = np.array([2, 1, 3])")
9.980971871998918
>>> timeit.timeit("np.array_equal(a, b)", setup="import numpy as np; a = np.array([1, 2, 3]); b = np.array([2, 1, 3])")
13.838635700998566
但无法正确使用多维数组:
>>> a = np.arange(9).reshape(3, 3)
>>> b = a.copy()
>>> b[0, 0] = 42
>>> all(a==b)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
>>> np.all(a==b)
False
对于较大的数组,np.all
最快:
>>> timeit.timeit("np.all(a==b)", setup="import numpy as np; a = np.arange(1000); b = a.copy(); b[999] = 0")
13.581198551000853
>>> timeit.timeit("all(a==b)", setup="import numpy as np; a = np.arange(1000); b = a.copy(); b[999] = 0")
30.610838356002205
>>> timeit.timeit("np.array_equal(a, b)", setup="import numpy as np; a = np.arange(1000); b = a.copy(); b[999] = 0")
17.95089965599982