比较多个numpy数组

时间:2016-06-12 18:10:30

标签: python arrays numpy

我应该如何比较2个以上的numpy阵列?

import numpy 
a = numpy.zeros((512,512,3),dtype=numpy.uint8)
b = numpy.zeros((512,512,3),dtype=numpy.uint8)
c = numpy.zeros((512,512,3),dtype=numpy.uint8)
if (a==b==c).all():
     pass

这给出了一个valueError,我对一次比较两个数组不感兴趣。

4 个答案:

答案 0 :(得分:4)

对于三个数组,你可以检查第一个和第二个数组之间相应元素之间的相等性,然后检查第二个和第三个数组,给我们两个布尔标量,最后看看这两个标量是否都是Working on 09 - Groove #2 (instrumental studio outtake).flac Working on h Sand Castles and Glass Camels Unusual Occurrences in the Desert.flac Working on 12 - Proto 18 Proper (instrumental studio outtake).flac Working on r Stopped.flac Working on 13 - Hollywood Cantana (studio outtake).flac ... 标量输出,如此 -

True

对于更多数量的数组,您可以堆叠它们,沿着堆叠轴获得区别,并检查这些差异的全部是否等于零。如果是,我们在所有输入数组中都是相等的,否则就没有。实现看起来像这样 -

np.logical_and( (a==b).all(), (b==c).all() )

答案 1 :(得分:4)

对于三个数组,你应该一次只比较两个:

if np.array_equal(a, b) and np.array_equal(b, c):
    do_whatever()

对于可变数量的数组,我们假设它们全部组合成一个大数组arrays。然后你可以做

if np.all(arrays[:-1] == arrays[1:]):
    do_whatever()

答案 2 :(得分:0)

要扩展先前的答案,我将使用INSERT INTO TableA (TableBID, Column1, Column2) SELECT TableAID, Column1, Column2 FROM TableA_Staging WHERE TableAID > X 中的combinations构造所有对,然后在每个对上进行比较。例如,如果我有三个数组,并想确认它们都相等,我将使用:

itertools

答案 3 :(得分:0)

支持不同形状和结构的解决方案

与数组列表的第一个元素进行比较:

import numpy as np

a = np.arange(3)
b = np.arange(3)
c = np.arange(3)
d = np.arange(4)

lst_eq = [a, b, c]
lst_neq = [a, b, d]

def all_equal(lst):
    for arr in lst[1:]:
        if not np.array_equal(lst[0], arr, equal_nan=True):
            return False
    return True

print('all_equal(lst_eq)=', all_equal(lst_eq))
print('all_equal(lst_neq)=', all_equal(lst_neq))

输出

all_equal(lst_eq)= True
all_equal(lst_neq)= False

对于相同的形状并且没有nan-support

将所有内容组合成一个数组,计算沿新轴的绝对差异,并检查沿新维度的最大元素是否等于 0 或低于某个阈值。这应该很快。

import numpy as np

a = np.arange(3)
b = np.arange(3)
c = np.arange(3)
d = np.array([0, 1, 3])

lst_eq = [a, b, c]
lst_neq = [a, b, d]

def all_equal(lst, threshold = 0):
    arr = np.stack(lst, axis=0)

    return np.max(np.abs(np.diff(arr, axis=0))) <= threshold

print('all_equal(lst_eq)=', all_equal(lst_eq))
print('all_equal(lst_neq)=', all_equal(lst_neq))

输出

all_equal(lst_eq)= True
all_equal(lst_neq)= False