结果时的稀疏矩阵乘法'稀疏性已知(在python | scipy | cython中)

时间:2014-07-28 22:06:20

标签: python numpy matrix scipy cython

假设我们想为给定的稀疏矩阵A,B计算C = A * B但是对C的条目的一小部分感兴趣,由索引对列表表示:
rows = [i1,i2,i3 ...]
cols = [j1,j2,j3 ...]
A和B都很大(比如说50Kx50K),但非常稀疏(<1%的条目非零)。

我们如何计算乘法的这个子集?

这是一个非常简单的天真实现:

def naive(A, B, rows, cols):
    N = len(rows)
    vals = []
    for n in xrange(N):
        v = A.getrow(rows[n]) * B.getcol(cols[n])
        vals.append(v[0, 0])

    R = sps.coo_matrix((np.array(vals), (np.array(rows), np.array(cols))), shape=(A.shape[0], B.shape[1]), dtype=np.float64)
    return R

即使对于小型矩阵,这也是非常糟糕的:

import scipy.sparse as sps
import numpy as np
D = 1000

A = np.random.randn(D, D)
A[np.abs(A) > 0.1] = 0
A = sps.csr_matrix(A)
B = np.random.randn(D, D)
B[np.abs(B) > 0.1] = 0
B = sps.csr_matrix(B)

X = np.random.randn(D, D)
X[np.abs(X) > 0.1] = 0
X[X != 0] = 1
X = sps.csr_matrix(X)
rows, cols = X.nonzero()
naive(A, B, rows, cols)

在我的机器上,naive()在1分钟后完成,大部分工作都花在构造行/列上(在getrow(),getcol()中)。
当然,将这个(非常小的)示例转换为密集矩阵,计算大约需要100ms:

A0 = np.array(A.todense())
B0 = np.array(B.todense())
X0 = np.array(X.todense())
A0.dot(B0) * X0

关于如何有效地计算这种矩阵乘法的任何想法?

1 个答案:

答案 0 :(得分:4)

稀疏矩阵的格式在这里很重要。您总是需要一个A行和一个B行。因此,将A存储为csr而将B存储为csc以删除getrow / getcol开销。不幸的是,这只是故事的一小部分。

最好的解决方案很大程度上取决于稀疏矩阵的结构(很多稀疏列/行等),但您可以尝试基于字典和集合。对于每行的矩阵A,保留以下内容:

  • 该行上包含所有非零列索引的集合
  • 一个字典,其中非零索引作为键,相应的非零值作为值

对于矩阵B,每列保留相似的dicts和集。

要计算乘法结果中的元素(M,N),A的行M乘以B的列N.乘法:

  • 找到非零集合的集合交集
  • 计算非零元素(即上面的交点)的乘法和

在大多数情况下,这应该非常快,因为在稀疏矩阵中,集合交点通常非常小。

一些代码:

class rowarray():
    def __init__(self, arr):
        self.rows = []
        for row in arr:
            nonzeros = np.nonzero(row)[0]
            nzvalues = { i: row[i] for i in nonzeros }
            self.rows.append((set(nonzeros), nzvalues))

    def __getitem__(self, key):
        return self.rows[key]

    def __len__(self):
        return len(self.rows)


class colarray(rowarray):
    def __init__(self, arr):
        rowarray.__init__(self, arr.T)


def maybe_less_naive(A, B, rows, cols):
    N = len(rows)
    vals = []
    for n in xrange(N):
        nz1,v1 = A[rows[n]]
        nz2,v2 = B[cols[n]]
        # list of common non-zeros
        nz = nz1.intersection(nz2)
        # sum of non-zeros
        vals.append(sum([ v1[i]*v2[i] for i in nz]))

    R = sps.coo_matrix((np.array(vals), (np.array(rows), np.array(cols))), shape=(len(A), len(B)), dtype=np.float64)
    return R

D = 1000

Ap = np.random.randn(D, D)
Ap[np.abs(Ap) > 0.1] = 0
A = rowarray(Ap)
Bp = np.random.randn(D, D)
Bp[np.abs(Bp) > 0.1] = 0
B = colarray(Bp)

X = np.random.randn(D, D)
X[np.abs(X) > 0.1] = 0
X[X != 0] = 1
X = sps.csr_matrix(X)
rows, cols = X.nonzero()
maybe_less_naive(A, B, rows, cols)

这样效率更高,测试乘法大约需要2秒(80 000个元素)。结果似乎基本相同。


关于表现的一些评论。

为每个输出元素执行了两个操作:

  • 设置交叉点
  • 乘法

集合交集的复杂度应为O(min(m,n)),其中m和n是每个操作数中非零的数量。这与矩阵的大小不变,只有每行/每列的非零的平均数量很重要。

乘法(和dict查找)的数量取决于上面交叉点中找到的非零数。

如果两个矩阵都有随机分布的非零值,概率(密度)为p,行/列长度为n,那么:

  • 设置交集:O(np)
  • 字典查找,乘法:O(np ^ 2)

这表明,对于非常稀疏的矩阵,找到交叉点是关键点。这也可以通过剖析来验证;大部分时间用于计算交叉点。

当这反映在现实世界中时,我们似乎花了大约20美元用于80个非零的行/列。这不是非常快,代码当然可以更快。 Cython可能是一个解决方案,但这可能是Python不是最好的解决方案的问题之一。当用C语言编写时,排序整数的简单线性匹配(合并排序类型算法)应该至少快一个数量级。

需要注意的一件重要事情是,算法可以一次为多个元素并行完成。没有必要解决单个线程,因为只要一个线程处理一个输出点,计算就是独立的。