scipy.sparse:将行设置为零

时间:2012-08-26 12:01:39

标签: row scipy sparse-matrix

假设我有一个CSR格式的矩阵,将行(或行)设置为零的最有效方法是什么?

以下代码运行缓慢:

A = A.tolil()
A[indices, :] = 0
A = A.tocsr()

我必须转换为scipy.sparse.lil_matrix,因为CSR格式似乎既不支持花哨的索引,也不支持将值设置为切片。

4 个答案:

答案 0 :(得分:6)

我猜scipy只是没有实现它,但CSR格式会很好地支持这一点,请阅读维基百科关于“稀疏矩阵”的文章关于indptr等等:

# A.indptr is an array, one for each row (+1 for the nnz):

def csr_row_set_nz_to_val(csr, row, value=0):
    """Set all nonzero elements (elements currently in the sparsity pattern)
    to the given value. Useful to set to 0 mostly.
    """
    if not isinstance(csr, scipy.sparse.csr_matrix):
        raise ValueError('Matrix given must be of CSR format.')
    csr.data[csr.indptr[row]:csr.indptr[row+1]] = value

# Now you can just do:
for row in indices:
    csr_row_set_nz_to_val(A, row, 0)

# And to remove zeros from the sparsity pattern:
A.eliminate_zeros()

当然,这将从稀疏模式中删除从eliminate_zeros的另一个地方设置的0。如果你想这样做(在这一点上)取决于你在做什么,即。消除可能有意义延迟,直到所有其他可能添加新零的计算也完成,或者在某些情况下你可能有0个值,你想稍后再次更改,所以消除它们是非常糟糕的!

原则上你当然可以将eliminate_zerosprune短路,但这应该是很麻烦的,甚至可能更慢(因为你不会在C中这样做)


有关eliminiate_zeros(和修剪)的详细信息

稀疏矩阵通常不会保存零元素,而只是存储非零元素的位置(粗略地和各种方法)。 eliminate_zeros从稀疏模式中删除矩阵中的所有零(即没有存储该位置的值,当 存储一个vlaue之前,但它为0时)。如果你想稍后将0更改为不同的值,则消除是不好的,否则会节省空间。

Prune只会缩小存储的数据数组,只要它们需要更长时间。请注意,当我第一次使用A.prune()时,A.eliminiate_zeros()已包含修剪。

答案 1 :(得分:0)

更新到scipy的最新版本。它支持花式索引。

答案 2 :(得分:0)

您可以使用矩阵点积来实现归零。由于我们将使用的矩阵非常稀疏(对于行/列的对角线,我们将其归零),乘法应该是有效的。

您需要以下功能之一:

import scipy.sparse

def zero_rows(M, rows):
    diag = scipy.sparse.eye(M.shape[0]).tolil()
    for r in rows:
        diag[r, r] = 0
    return diag.dot(M)

def zero_columns(M, columns):
    diag = scipy.sparse.eye(M.shape[1]).tolil()
    for c in columns:
        diag[c, c] = 0
    return M.dot(diag)

用法示例:

>>> A = scipy.sparse.csr_matrix([[1,0,3,4], [5,6,0,8], [9,10,11,0]])
>>> A
<3x4 sparse matrix of type '<class 'numpy.int64'>'
        with 9 stored elements in Compressed Sparse Row format>
>>> A.toarray()
array([[ 1,  0,  3,  4],
       [ 5,  6,  0,  8],
       [ 9, 10, 11,  0]], dtype=int64)

>>> B = zero_rows(A, [1])
>>> B
<3x4 sparse matrix of type '<class 'numpy.float64'>'
        with 6 stored elements in Compressed Sparse Row format>
>>> B.toarray()
array([[  1.,   0.,   3.,   4.],
       [  0.,   0.,   0.,   0.],
       [  9.,  10.,  11.,   0.]])

>>> C = zero_columns(A, [1, 3])
>>> C
<3x4 sparse matrix of type '<class 'numpy.float64'>'
        with 5 stored elements in Compressed Sparse Row format>
>>> C.toarray()
array([[  1.,   0.,   3.,   0.],
       [  5.,   0.,   0.,   0.],
       [  9.,   0.,  11.,   0.]])

答案 3 :(得分:0)

我想完成answer given by @seberg。如果要将nnz值设置为零,则应该修改CSR矩阵的结构,而不仅仅是修改.data属性。

此代码的当前行为是

>>> import scipy.sparse
>>> import numpy as np
>>> A = scipy.sparse.csr_matrix([[0,1,0], [2,0,3], [0,0,0], [4,0,0]])
>>> A.toarray()
array([[0, 1, 0],
       [2, 0, 3],
       [0, 0, 0],
       [4, 0, 0]], dtype=int64)
>>> csr_row_set_nz_to_val(A, 1)
>>> A.toarray()
array([[0, 1, 0],
       [0, 0, 0],
       [0, 0, 0],
       [4, 0, 0]], dtype=int64)
>>> A.data
array([1, 0, 0, 4], dtype=int64)
>>> A.indices
array([1, 0, 2, 0], dtype=int32)
>>> A.indptr
array([0, 1, 3, 3, 4], dtype=int32)

由于我们正在处理稀疏矩阵,因此我们不希望A.data数组中为零。我认为应该将csr_row_set_nz_to_val修改如下

def csr_row_set_nz_to_val(csr, row, value=0):
    """Set all nonzero elements of a CSR matrix M (elements currently in the sparsity pattern)
    to the given value. Useful to set to 0 mostly.
    """
    if not isinstance(csr, scipy.sparse.csr_matrix):
        raise ValueError("Matrix given must be of CSR format.")
    if value == 0:
        csr.data = np.delete(csr.data, range(csr.indptr[row], csr.indptr[row+1])) # drop nnz values
        csr.indices = np.delete(csr.indices, range(csr.indptr[row], csr.indptr[row+1])) # drop nnz column indices
        csr.indptr[(row+1):] = csr.indptr[(row+1):] - (csr.indptr[row+1] - csr.indptr[row])
    else:
        csr.data[csr.indptr[row]:csr.indptr[row+1]] = value # replace nnz values by another nnz value

最后,我们会得到

>>> A = scipy.sparse.csr_matrix([[0,1,0], [2,0,3], [0,0,0], [4,0,0]])
>>> csr_row_set_nz_to_val(A, 1)
>>> A.toarray()
array([[0, 1, 0],
       [0, 0, 0],
       [0, 0, 0],
       [4, 0, 0]], dtype=int64)
>>> A.data
array([1, 4], dtype=int64)
>>> A.indices
array([1, 0], dtype=int32)
>>> A.indptr
array([0, 1, 1, 1, 2], dtype=int32)