我正在尝试使用Numba和它的JIT编译器在Python中加速一些稀疏矩阵 - 矩阵乘法。不幸的是,它不支持我需要的SciPy库。
我的解决方案是将函数csr_matmat_pass1()
和csr_matmat_pass2()
从here转换为Python代码。我的代码似乎适用于小于~80x80的矩阵,并提供正确的结果。这是我的解决方案:
import scipy.sparse as sparse
import numpy as np
def csr_matmat_pass1(n_row, n_col, Ap, Aj, Bp, Bj):
mask = np.ones(n_col, dtype="int") * -1
Cp = np.zeros(n_row+1, dtype="int")
nnz = 0
for i in range(n_row):
row_nnz = 0
for jj in range(Ap[i],Ap[i+1]):
j = Aj[jj]
for kk in range(Bp[j],Bp[j+1]):
k = Bj[kk]
if(mask[k] != i):
mask[k] = i
row_nnz += 1
next_nnz = nnz + row_nnz;
nnz = next_nnz;
Cp[i+1] = nnz;
return Cp
def csr_matmat_pass2(n_row, n_col, Ap, Aj, Ax, Bp, Bj, Bx, Cp):
nextV = np.ones(n_col, dtype="int") * -1
sums = np.zeros(n_col)
nnz = 0
Cp[0] = 0
#preallocate space
sizeC = max(len(Ax),len(Bx))
Cj = np.zeros(sizeC, dtype="int")
Cx = np.zeros(sizeC)
for i in range(n_row):
head = -2
length = 0
jj_start = Ap[i]
jj_end = Ap[i+1]
for jj in range(jj_start,jj_end):
j = Aj[jj]
v = Ax[jj]
kk_start = Bp[j]
kk_end = Bp[j+1]
for kk in range(kk_start,kk_end):
k = Bj[kk]
sums[k] += v*Bx[kk]
if(nextV[k] == -1):
nextV[k] = head
head = k
length += 1
for jj in range(length):
if(sums[head] != 0.0):
Cj[nnz] = head
Cx[nnz] = sums[head]
nnz += 1
temp = head
head = nextV[head]
nextV[temp] = -1
sums[temp] = 0
Cp[i+1] = nnz
return Cp, Cj, Cx
#calculate random sparse matrices A,B
mSize = 50
A = sparse.random(mSize, mSize, 0.01).tocsr()
B = sparse.random(mSize, mSize, 0.01).tocsr()
#calculate sparse C
Cp = csr_matmat_pass1(np.shape(A)[0], np.shape(B)[1], A.indptr, A.indices, B.indptr, B.indices)
Cp, Cj, Cx = csr_matmat_pass2(np.shape(A)[0], np.shape(B)[1], A.indptr, A.indices, A.data, B.indptr, B.indices, B.data, Cp)
#generate numpy sparse matrix from Cx, Cj, Cp
C = sparse.csr_matrix((Cx,Cj,Cp),shape=(nRow,nCol))
diffC = A.dot(B) - C
#validate function -> check if any nonzero element is present. If so -> calc is wrong
if np.any(diffC.todense()): UserWarning('Calculations are wrong')
当增加矩阵的大小时(假设为mSize=100
),我收到以下错误:
line 168, in csr_matmat_pass2 Cj[nnz] = head
IndexError: index 72 is out of bounds for axis 0 with size 72
我认为错误是在我的python转换中而不是在C ++代码中(因为它来自scipy库)。此外,Cp
的条目数大于矩阵A
,B
的大小。因此,csr_matmat_pass1()
的翻译必定存在错误。不幸的是,我找不到任何语法错误,也不知道为什么nnz
变得比它应该大。