我在MySQL中有一个包含三列的表:row-index,column-index和value,我想将其读入scipy csr_matrix。我使用Python-MySQL连接器。有112,500个非零元素。
尝试1:
A = csr_matrix((N_rows, N_cols), dtype=float)
show = 'SELECT * FROM my_table'
cursor.execute(show)
for (row, col, value) in cursor:
A[row, col] = value
这太慢了,我不得不在60秒后停止它。它提到了效率警告,并建议使用lil矩阵。
尝试2:
A = lil_matrix((N_rows, N_cols), dtype=float)
show = 'SELECT * FROM my_table'
cursor.execute(show)
for (row, col, value) in cursor:
A[row, col] = value
A = csr_matrix(A)
这需要6.4秒(平均三个)。这是不是很好,或者是否有更快的方法可以创建csr_matrix而无需通过循环?如果我执行cursor.fetchall(),数据看起来像:
[(row_0, col_0, value_0), (row_1, col_1, value_1), ...]
这不能用于csr_matrix构造函数。
答案 0 :(得分:4)
cursor.fetchall()
返回的数据几乎是coo_matrix格式。
你可以做到
import numpy as np
from scipy.sparse import coo_matrix
data = cursor.fetchall()
#data = [(1, 2, 1.2), (3, 4, 7.1)]
arr = np.array(data, dtype=[('row', int), ('col', int), ('value', float)])
spmat = coo_matrix((arr['value'], (arr['row'], arr['col'])))
而不是np.array(cursor.fetchall(), ...)
,您也可以优先使用
arr = np.fromiter(cursor, dtype=[('row', int), ('col', int), ('value', float)])
将数据从DB直接加载到Numpy数组中。