我有两个变量r
和e
,它们都是字典,字符串为键,csr_matrices为值。现在我想断言他们是平等的。我该怎么做?
尝试1:
from scipy.sparse.csr import csr_matrix
import numpy as np
def test_dict_equals(self):
r = {'a': csr_matrix([[0, 0 ,1], [0, 1, 0], [1, 0, 0]])}
e = {'a': csr_matrix([[0, 0 ,1], [0, 1, 0], [1, 0, 0]])}
self.assertDictEqual(r, e)
这不起作用:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all().
尝试2:
def test_dict_equals(self):
r = {'a': csr_matrix([[0, 0 ,1.01], [0, 1, 0], [1, 0, 0]])}
e = {'a': csr_matrix([[0, 0 ,1.01], [0, 1, 0], [1, 0, 0]])}
self.assertListEqual(r.keys(), e.keys())
for k in r.keys():
np.testing.assert_allclose(r[k], e[k])
这也行不通:
AssertionError: First sequence is not a list: dict_keys(['a'])
尝试3:
def test_dict_equals(self):
r = {'a': csr_matrix([[0, 0 ,1.01], [0, 1, 0], [1, 0, 0]])}
e = {'a': csr_matrix([[0, 0 ,1.01], [0, 1, 0], [1, 0, 0]])}
self.assertListEqual(list(r.keys()), list(e.keys()))
for k in r.keys():
np.testing.assert_allclose(r[k], e[k])
这也行不通:
TypeError: ufunc 'isinf' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''
答案 0 :(得分:1)
assertDictEqual
函数将调用对象的__eq__
方法。在csr_matrix
的源代码中,您可以看到没有__eq__
方法。
您必须编写csr_matrix
的子类,然后执行断言。以下是numpy.ndarray
的示例。代码必须类似。
import copy
import numpy
import unittest
class SaneEqualityArray(numpy.ndarray):
def __eq__(self, other):
return (isinstance(other, SaneEqualityArray) and
self.shape == other.shape and
numpy.ndarray.__eq__(self, other).all())
class TestAsserts(unittest.TestCase):
def testAssert(self):
tests = [
[1, 2],
{'foo': 2},
[2, 'foo', {'d': 4}],
SaneEqualityArray([1, 2]),
{'foo': {'hey': SaneEqualityArray([2, 3])}},
[{'foo': SaneEqualityArray([3, 4]), 'd': {'doo': 3}},
SaneEqualityArray([5, 6]), 34]
]
for t in tests:
self.assertEqual(t, copy.deepcopy(t))
if __name__ == '__main__':
unittest.main()
希望它有所帮助。:)
答案 1 :(得分:1)
暂时忘掉字典,并专注于比较2个sparse
矩阵。它们不是numpy
数组,因此您无法直接使用np
方法。这就是为什么你的第三次尝试不起作用。
有一个scipy.sparse
单元测试目录。我还没有检查过它,但它可能会给你超出我建议的想法。
https://github.com/scipy/scipy/tree/master/scipy/sparse/tests
A=sparse.csr_matrix(np.arange(9).reshape(3,3))
B=sparse.csr_matrix(np.arange(9).reshape(3,3))
它们是不同的对象
id(A)==id(B) # False
他们拥有相同数量的非零
A.nnz == B.nnz # True - just a comparison of 2 numbers
此稀疏格式的数据包含在3个数组A.data
,A.indices
,A.indptr
中。因此,您可以使用np
方法来测试其中一个或多个
np.allclose(A.data, B.data) # this would also compare dtype
您还可以比较形状等。
较新版本的scipy
已经为稀疏矩阵实现了逐元素比较器。 ==
已实施但可能会给您一个警告:
SparseEfficiencyWarning:使用==比较稀疏矩阵是低效的,请尝试使用!=而不是。
如果形状匹配,这可能是与稀疏矩阵进行比较的有效方法:
(A!=B).nnz==0
如果形状不匹配,A!=C
会返回True
如果它们很小,你可以比较它们的密集等价物:
np.allclose(A.A, B.A)