CSR矩阵和ndarray的点积更快

时间:2019-02-08 19:32:34

标签: python numpy scipy

我正在尝试获得一个快速的点积函数,以将稀疏矩阵(3 * 3)和nd数组(1 * 3)相乘,使得矩阵的每一行都得到与nd数组的点积,从而得到一个(3 * 1)个数组。

我目前的实现方式是获取矩阵的每一行,然后进行点积运算,但是要扩大矩阵尺寸,它会变得太慢。

row = np.array([0, 0, 1, 2, 2, 2])
col = np.array([0, 2, 2, 0, 1, 2])
data = np.array([1, 2, 3, 4, 5, 6])
matrix =csr_matrix((data, (row, col)), shape=(3, 3))
secondArray=np.random.rand((1, 3))

for idx, x in enumerate(matrix):
    X_arr=X.getrow(idx).toarray()
    prod=np.dot(np.array(X_arr[0]), secondArray)

0 个答案:

没有答案