我想断言两个Python字典是相等的(这意味着:等量的密钥,每个从键到值的映射是相等的;顺序并不重要)。一种简单的方法是assert A==B
,但是,如果词典的值为numpy arrays
,则不起作用。如果两个词典相同,我怎样才能编写一个函数来检查?
>>> import numpy as np
>>> A = {1: np.identity(5)}
>>> B = {1: np.identity(5) + np.ones([5,5])}
>>> A == B
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
编辑我知道应该检查numpy矩阵是否与.all()
相等。我正在寻找的是检查这一点的一般方法,而无需检查isinstance(np.ndarray)
。这有可能吗?
没有numpy数组的相关主题:
答案 0 :(得分:13)
您可以使用numpy.testing.assert_equal
http://docs.scipy.org/doc/numpy/reference/generated/numpy.testing.assert_equal.html
答案 1 :(得分:1)
我将回答隐藏在您的问题标题和上半部分中的一半问题,因为坦率地说,这是一个要解决的更为普遍的问题,现有的答案不能很好地解决。这个问题是“ 如何比较numpy数组的两个字典是否相等”?
问题的第一部分是“从远处”检查字典:看到它们的键是相同的。如果所有键都相同,则第二部分将比较每个对应的值。
现在,一个微妙的问题是很多numpy数组不是整数值,double-precision is imprecise。因此,除非您拥有整数值(或其他非浮点型)数组,否则您可能需要检查值是否几乎相同,即在机器精度内。因此,在这种情况下,您将不使用np.array_equal
(检查精确的数值相等性),而是使用np.allclose
(对两个数组之间的相对误差和绝对误差使用有限的公差)。
问题的前半部分很简单:检查字典的键是否一致,并使用生成器理解比较每个值(并在理解之外使用all
来验证每个项目一样):
import numpy as np
# some dummy data
# these are equal exactly
dct1 = {'a': np.array([2, 3, 4])}
dct2 = {'a': np.array([2, 3, 4])}
# these are equal _roughly_
dct3 = {'b': np.array([42.0, 0.2])}
dct4 = {'b': np.array([42.0, 3*0.1 - 0.1])} # still 0.2, right?
def compare_exact(first, second):
"""Return whether two dicts of arrays are exactly equal"""
if first.keys() != second.keys():
return False
return all(np.array_equal(first[key], second[key]) for key in first)
def compare_approximate(first, second):
"""Return whether two dicts of arrays are roughly equal"""
if first.keys() != second.keys():
return False
return all(np.allclose(first[key], second[key]) for key in first)
# let's try them:
print(compare_exact(dct1, dct2)) # True
print(compare_exact(dct3, dct4)) # False
print(compare_approximate(dct3, dct4)) # True
正如您在上面的示例中看到的那样,整数数组比较精确,并且根据您正在执行的操作(或者如果您很幸运),它甚至可以用于浮点数。但是,如果浮点数是任何算术运算(例如,线性转换?)的结果,则绝对应该使用近似检查。有关后一种选项的完整说明,请参见the docs of numpy.allclose
(及其元素级的朋友numpy.isclose
),并特别注意rtol
和atol
关键字参数。
答案 2 :(得分:-1)
考虑这段代码
>>> import numpy as np
>>> np.identity(5)
array([[ 1., 0., 0., 0., 0.],
[ 0., 1., 0., 0., 0.],
[ 0., 0., 1., 0., 0.],
[ 0., 0., 0., 1., 0.],
[ 0., 0., 0., 0., 1.]])
>>> np.identity(5)+np.ones([5,5])
array([[ 2., 1., 1., 1., 1.],
[ 1., 2., 1., 1., 1.],
[ 1., 1., 2., 1., 1.],
[ 1., 1., 1., 2., 1.],
[ 1., 1., 1., 1., 2.]])
>>> np.identity(5) == np.identity(5)+np.ones([5,5])
array([[False, False, False, False, False],
[False, False, False, False, False],
[False, False, False, False, False],
[False, False, False, False, False],
[False, False, False, False, False]], dtype=bool)
>>>
请注意,比较结果是矩阵,而不是布尔值。 Dict比较将使用值 cmp 方法比较值,这意味着在比较矩阵值时,dict比较将得到复合结果。你想要做的就是使用 numpy.all将复合数组结果折叠为标量布尔结果
>>> np.all(np.identity(5) == np.identity(5)+np.ones([5,5]))
False
>>> np.all(np.identity(5) == np.identity(5))
True
>>>
您需要编写自己的函数来比较这些词典,测试值类型以查看它们是否为matricies,然后使用numpy.all
进行比较,否则使用==
。当然,如果你也愿意,你可以随时获得想象并开始继承dict并重载 cmp 。