如何测试所有行是否在numpy中相等

时间:2014-10-02 15:02:05

标签: python arrays numpy

在numpy中,如果所有行在二维数组中相等,是否有一种很好的惯用方法?

我可以做类似

的事情
np.all([np.array_equal(M[0], M[i]) for i in xrange(1,len(M))])

这似乎将python列表与numpy数组混合在一起,这些数组很难看,而且可能也很慢。

有更好/更整洁的方式吗?

3 个答案:

答案 0 :(得分:17)

一种方法是检查数组arr的每一行是否等于第一行arr[0]

(arr == arr[0]).all()

对于整数值,使用相等==是正常的,但如果arr包含浮点值,则可以使用np.isclose来检查给定容差内的相等性:

np.isclose(a, a[0]).all()

如果您的数组包含NaN并且您想避免棘手的NaN != NaN问题,则可以将此方法与np.isnan结合使用:

(np.isclose(a, a[0]) | np.isnan(a)).all()

答案 1 :(得分:5)

只需检查数组中唯一项目的编号是否为1:

>>> arr = np.array([[1]*10 for _ in xrange(5)])
>>> len(np.unique(arr)) == 1
True

<击>

灵感来自unutbu&#39; answer的解决方案:

>>> arr = np.array([[1]*10 for _ in xrange(5)])
>>> np.all(np.all(arr == arr[0,:], axis = 1))
True

您的代码存在的一个问题是,在对其应用np.all()之前,您首先要创建整个列表。由于您的版本中没有发生短路,而不是使用带有生成器表达式的Python all()会更好:

时间比较:

>>> M = arr = np.array([[3]*100] + [[2]*100 for _ in xrange(1000)])
>>> %timeit np.all(np.all(arr == arr[0,:], axis = 1))
1000 loops, best of 3: 272 µs per loop
>>> %timeit (np.diff(M, axis=0) == 0).all()
1000 loops, best of 3: 596 µs per loop
>>> %timeit np.all([np.array_equal(M[0], M[i]) for i in xrange(1,len(M))])
100 loops, best of 3: 10.6 ms per loop
>>> %timeit all(np.array_equal(M[0], M[i]) for i in xrange(1,len(M)))
100000 loops, best of 3: 11.3 µs per loop

>>> M = arr = np.array([[2]*100 for _ in xrange(1000)])
>>> %timeit np.all(np.all(arr == arr[0,:], axis = 1))
1000 loops, best of 3: 330 µs per loop
>>> %timeit (np.diff(M, axis=0) == 0).all()
1000 loops, best of 3: 594 µs per loop
>>> %timeit np.all([np.array_equal(M[0], M[i]) for i in xrange(1,len(M))])
100 loops, best of 3: 9.51 ms per loop
>>> %timeit all(np.array_equal(M[0], M[i]) for i in xrange(1,len(M)))
100 loops, best of 3: 9.44 ms per loop

答案 2 :(得分:3)

值得一提的是above version不适用于多维数组。

例如:对于三维方形图像张量img [256,256,3],我们需要检查图像中是否有相同的RGB [256,256]层。 在这种情况下,我们需要使用broadcasting

(img == img[:, :, 0, np.newaxis]).all()

因为简单img[:, :, 0]给我们[256,256],但我们需要[256,256,1]来通过图层进行广播。