在python中对矩阵对角线位置的行条目求和的快速方法

时间:2017-10-17 00:22:25

标签: python-2.7 python-3.x numpy scipy

您好我正在尝试解决下面的等式,其中A是稀疏矩阵,ptotal是数字数组。我必须在对角线位置连续汇总所有条目。

A[ptotal, ptotal] = -sum(A[ptotal, :])

代码似乎给出了正确答案,但由于我的ptotal数组差不多(100000个条目),因此计算效率不高。有没有快速的方法来解决这个问题。

1 个答案:

答案 0 :(得分:0)

首先是密集阵列版本:

In [87]: A = np.arange(36).reshape(6,6)
In [88]: ptotal = np.arange(6)

假设ptotal是所有行索引,可以用sum方法调用替换它:

In [89]: sum(A[ptotal,:])
Out[89]: array([ 90,  96, 102, 108, 114, 120])
In [90]: A.sum(axis=0)
Out[90]: array([ 90,  96, 102, 108, 114, 120])

我们可以在对角线上创建一个包含这些值的数组:

In [92]: np.diagflat(A.sum(axis=0))
Out[92]: 
array([[ 90,   0,   0,   0,   0,   0],
       [  0,  96,   0,   0,   0,   0],
       [  0,   0, 102,   0,   0,   0],
       [  0,   0,   0, 108,   0,   0],
       [  0,   0,   0,   0, 114,   0],
       [  0,   0,   0,   0,   0, 120]])

将其添加到原始数组中 - 结果是“零和”#39;阵列:

In [93]: A -= np.diagflat(A.sum(axis=0))
In [94]: A
Out[94]: 
array([[-90,   1,   2,   3,   4,   5],
       [  6, -89,   8,   9,  10,  11],
       [ 12,  13, -88,  15,  16,  17],
       [ 18,  19,  20, -87,  22,  23],
       [ 24,  25,  26,  27, -86,  29],
       [ 30,  31,  32,  33,  34, -85]])
In [95]: A.sum(axis=0)
Out[95]: array([0, 0, 0, 0, 0, 0])

我们可以用稀疏

做同样的事情
In [99]: M = sparse.csr_matrix(np.arange(36).reshape(6,6))
In [100]: M
Out[100]: 
<6x6 sparse matrix of type '<class 'numpy.int32'>'
    with 35 stored elements in Compressed Sparse Row format>
In [101]: M.sum(axis=0)
Out[101]: matrix([[ 90,  96, 102, 108, 114, 120]], dtype=int32)

稀疏对角矩阵:

In [104]: sparse.dia_matrix((M.sum(axis=0),0),M.shape)
Out[104]: 
<6x6 sparse matrix of type '<class 'numpy.int32'>'
    with 6 stored elements (1 diagonals) in DIAgonal format>
In [105]: _.A
Out[105]: 
array([[ 90,   0,   0,   0,   0,   0],
       [  0,  96,   0,   0,   0,   0],
       [  0,   0, 102,   0,   0,   0],
       [  0,   0,   0, 108,   0,   0],
       [  0,   0,   0,   0, 114,   0],
       [  0,   0,   0,   0,   0, 120]], dtype=int32)

采取不同之处,获得一个新的矩阵:

In [106]: M-sparse.dia_matrix((M.sum(axis=0),0),M.shape)
Out[106]: 
<6x6 sparse matrix of type '<class 'numpy.int32'>'
    with 36 stored elements in Compressed Sparse Row format>
In [107]: _.A
Out[107]: 
array([[-90,   1,   2,   3,   4,   5],
       [  6, -89,   8,   9,  10,  11],
       [ 12,  13, -88,  15,  16,  17],
       [ 18,  19,  20, -87,  22,  23],
       [ 24,  25,  26,  27, -86,  29],
       [ 30,  31,  32,  33,  34, -85]], dtype=int32)

还有setdiag方法

In [117]: M.setdiag(-M.sum(axis=0).A1)
/usr/local/lib/python3.5/dist-packages/scipy/sparse/compressed.py:774: SparseEfficiencyWarning: Changing the sparsity structure of a csr_matrix is expensive. lil_matrix is more efficient.
  SparseEfficiencyWarning)
In [118]: M.A
Out[118]: 
array([[ -90,    1,    2,    3,    4,    5],
       [   6,  -96,    8,    9,   10,   11],
       [  12,   13, -102,   15,   16,   17],
       [  18,   19,   20, -108,   22,   23],
       [  24,   25,   26,   27, -114,   29],
       [  30,   31,   32,   33,   34, -120]], dtype=int32)

Out[101]是一个2d矩阵; .A1将其转换为setdiag可以使用的1d数组。

稀疏效率警告更多地针对迭代使用而不是像这样的一次性应用。不过,看一下setdiag代码,我怀疑第一种方法更快。但我们确实需要做时间测试。