从SQL有效地读取Python稀疏矩阵

时间:2014-09-19 14:05:37

标签: python mysql scipy sparse-matrix

我在MySQL中有一个包含三列的表:row-index,column-index和value,我想将其读入scipy csr_matrix。我使用Python-MySQL连接器。有112,500个非零元素。

尝试1:

A = csr_matrix((N_rows, N_cols), dtype=float)
show = 'SELECT * FROM my_table'
cursor.execute(show)
for (row, col, value) in cursor:
     A[row, col] = value

这太慢了,我不得不在60秒后停止它。它提到了效率警告,并建议使用lil矩阵。

尝试2:

A = lil_matrix((N_rows, N_cols), dtype=float)
show = 'SELECT * FROM my_table'
cursor.execute(show)
for (row, col, value) in cursor:
     A[row, col] = value
A = csr_matrix(A)

这需要6.4秒(平均三个)。这是不是很好,或者是否有更快的方法可以创建csr_matrix而无需通过循环?如果我执行cursor.fetchall(),数据看起来像:

[(row_0, col_0, value_0), (row_1, col_1, value_1), ...]

这不能用于csr_matrix构造函数。

1 个答案:

答案 0 :(得分:4)

cursor.fetchall()返回的数据几乎是coo_matrix格式。 你可以做到

import numpy as np
from scipy.sparse import coo_matrix

data = cursor.fetchall()
#data = [(1, 2, 1.2), (3, 4, 7.1)]

arr = np.array(data, dtype=[('row', int), ('col', int), ('value', float)])
spmat = coo_matrix((arr['value'], (arr['row'], arr['col'])))

而不是np.array(cursor.fetchall(), ...),您也可以优先使用

arr = np.fromiter(cursor, dtype=[('row', int), ('col', int), ('value', float)])

将数据从DB直接加载到Numpy数组中。