在numpy中,如果所有行在二维数组中相等,是否有一种很好的惯用方法?
我可以做类似
的事情np.all([np.array_equal(M[0], M[i]) for i in xrange(1,len(M))])
这似乎将python列表与numpy数组混合在一起,这些数组很难看,而且可能也很慢。
有更好/更整洁的方式吗?
答案 0 :(得分:17)
一种方法是检查数组arr
的每一行是否等于第一行arr[0]
:
(arr == arr[0]).all()
对于整数值,使用相等==
是正常的,但如果arr
包含浮点值,则可以使用np.isclose
来检查给定容差内的相等性:
np.isclose(a, a[0]).all()
如果您的数组包含NaN
并且您想避免棘手的NaN != NaN
问题,则可以将此方法与np.isnan
结合使用:
(np.isclose(a, a[0]) | np.isnan(a)).all()
答案 1 :(得分:5)
只需检查数组中唯一项目的编号是否为1:
>>> arr = np.array([[1]*10 for _ in xrange(5)])
>>> len(np.unique(arr)) == 1
True
击> <击> 撞击>
灵感来自unutbu&#39; answer的解决方案:
>>> arr = np.array([[1]*10 for _ in xrange(5)])
>>> np.all(np.all(arr == arr[0,:], axis = 1))
True
您的代码存在的一个问题是,在对其应用np.all()
之前,您首先要创建整个列表。由于您的版本中没有发生短路,而不是使用带有生成器表达式的Python all()
会更好:
时间比较:
>>> M = arr = np.array([[3]*100] + [[2]*100 for _ in xrange(1000)])
>>> %timeit np.all(np.all(arr == arr[0,:], axis = 1))
1000 loops, best of 3: 272 µs per loop
>>> %timeit (np.diff(M, axis=0) == 0).all()
1000 loops, best of 3: 596 µs per loop
>>> %timeit np.all([np.array_equal(M[0], M[i]) for i in xrange(1,len(M))])
100 loops, best of 3: 10.6 ms per loop
>>> %timeit all(np.array_equal(M[0], M[i]) for i in xrange(1,len(M)))
100000 loops, best of 3: 11.3 µs per loop
>>> M = arr = np.array([[2]*100 for _ in xrange(1000)])
>>> %timeit np.all(np.all(arr == arr[0,:], axis = 1))
1000 loops, best of 3: 330 µs per loop
>>> %timeit (np.diff(M, axis=0) == 0).all()
1000 loops, best of 3: 594 µs per loop
>>> %timeit np.all([np.array_equal(M[0], M[i]) for i in xrange(1,len(M))])
100 loops, best of 3: 9.51 ms per loop
>>> %timeit all(np.array_equal(M[0], M[i]) for i in xrange(1,len(M)))
100 loops, best of 3: 9.44 ms per loop
答案 2 :(得分:3)
值得一提的是above version不适用于多维数组。
例如:对于三维方形图像张量img
[256,256,3],我们需要检查图像中是否有相同的RGB [256,256]层。
在这种情况下,我们需要使用broadcasting
(img == img[:, :, 0, np.newaxis]).all()
因为简单img[:, :, 0]
给我们[256,256],但我们需要[256,256,1]来通过图层进行广播。