Python将dict与csr_matrices作为值进行比较

时间:2015-01-06 07:58:46

标签: python numpy dictionary sparse-matrix unit-testing

我有两个变量re,它们都是字典,字符串为键,csr_matrices为值。现在我想断言他们是平等的。我该怎么做?

尝试1:

from scipy.sparse.csr import csr_matrix
import numpy as np

def test_dict_equals(self):
    r = {'a': csr_matrix([[0, 0 ,1], [0, 1, 0], [1, 0, 0]])}
    e = {'a': csr_matrix([[0, 0 ,1], [0, 1, 0], [1, 0, 0]])}
    self.assertDictEqual(r, e)

这不起作用:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all().

尝试2:

def test_dict_equals(self):
    r = {'a': csr_matrix([[0, 0 ,1.01], [0, 1, 0], [1, 0, 0]])}
    e = {'a': csr_matrix([[0, 0 ,1.01], [0, 1, 0], [1, 0, 0]])}
    self.assertListEqual(r.keys(), e.keys())
    for k in r.keys():
        np.testing.assert_allclose(r[k], e[k])

这也行不通:

AssertionError: First sequence is not a list: dict_keys(['a'])

尝试3:

def test_dict_equals(self):
    r = {'a': csr_matrix([[0, 0 ,1.01], [0, 1, 0], [1, 0, 0]])}
    e = {'a': csr_matrix([[0, 0 ,1.01], [0, 1, 0], [1, 0, 0]])}
    self.assertListEqual(list(r.keys()), list(e.keys()))
    for k in r.keys():
        np.testing.assert_allclose(r[k], e[k])

这也行不通:

TypeError: ufunc 'isinf' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''

2 个答案:

答案 0 :(得分:1)

assertDictEqual函数将调用对象的__eq__方法。在csr_matrix的源代码中,您可以看到没有__eq__方法。

您必须编写csr_matrix的子类,然后执行断言。以下是numpy.ndarray的示例。代码必须类似。

import copy
import numpy
import unittest

class SaneEqualityArray(numpy.ndarray):
    def __eq__(self, other):
        return (isinstance(other, SaneEqualityArray) and
                self.shape == other.shape and
                numpy.ndarray.__eq__(self, other).all())

class TestAsserts(unittest.TestCase):

    def testAssert(self):
        tests = [
            [1, 2],
            {'foo': 2},
            [2, 'foo', {'d': 4}],
            SaneEqualityArray([1, 2]),
            {'foo': {'hey': SaneEqualityArray([2, 3])}},
            [{'foo': SaneEqualityArray([3, 4]), 'd': {'doo': 3}},
             SaneEqualityArray([5, 6]), 34]
        ]
        for t in tests:
            self.assertEqual(t, copy.deepcopy(t))

if __name__ == '__main__':
    unittest.main()

希望它有所帮助。:)

答案 1 :(得分:1)

暂时忘掉字典,并专注于比较2个sparse矩阵。它们不是numpy数组,因此您无法直接使用np方法。这就是为什么你的第三次尝试不起作用。

有一个scipy.sparse单元测试目录。我还没有检查过它,但它可能会给你超出我建议的想法。

https://github.com/scipy/scipy/tree/master/scipy/sparse/tests

A=sparse.csr_matrix(np.arange(9).reshape(3,3))
B=sparse.csr_matrix(np.arange(9).reshape(3,3))

它们是不同的对象

id(A)==id(B) # False

他们拥有相同数量的非零

A.nnz == B.nnz  # True - just a comparison of 2 numbers

此稀疏格式的数据包含在3个数组A.dataA.indicesA.indptr中。因此,您可以使用np方法来测试其中一个或多个

np.allclose(A.data, B.data)   # this would also compare dtype

您还可以比较形状等。

较新版本的scipy已经为稀疏矩阵实现了逐元素比较器。 ==已实施但可能会给您一个警告:

  

SparseEfficiencyWarning:使用==比较稀疏矩阵是低效的,请尝试使用!=而不是。

如果形状匹配,这可能是与稀疏矩阵进行比较的有效方法:

(A!=B).nnz==0 

如果形状不匹配,A!=C会返回True

如果它们很小,你可以比较它们的密集等价物:

np.allclose(A.A, B.A)