稀疏矩阵点积仅保留每个结果行的N-max值

时间:2015-04-15 10:04:39

标签: python numpy scipy linear-algebra sparse-matrix

我有一个非常庞大的csr稀疏矩阵M。我希望得到此矩阵的点积(M.dot(M.T)),并且在结果矩阵N中每行仅保留R个最大值。问题是点积M.dot(M.T)会引发MemoryError。所以我创建了点函数的修改实现,如下所示:

def dot_with_top(m1, m2, top=None):
    if top is not None and top > 0:
        res_rows = []
        for row_id in xrange(m1.shape[0]):

            row = m1[row_id]
            if row.nnz > 0:
                res_row = m1[row_id].dot(m2)
                if res_row.nnz > top:
                    args_ids = np.argsort(res_row.data)[-top:]
                    data = res_row.data[args_ids]
                    cols = res_row.indices[args_ids]
                    res_rows.append(csr_matrix((data, (np.zeros(top), cols)), shape=res_row.shape))
                else:
                    res_rows.append(res_row)
            else:
                res_rows.append(csr_matrix((1, m1.shape[0])))
        return sparse.vstack(res_rows, 'csr')
    return m1.dot(m2) 

它工作正常,但有点慢。是否可以更快地进行此计算,或者您可能知道一些已经存在的方法可以更快地完成此计算?

1 个答案:

答案 0 :(得分:1)

您可以在函数的行数上实现循环,并使用multiprocessing.Pool()对象调用此函数。 这将并行化循环的执行,并应添加一个很好的加速。

示例:

from multiprocessing import Pool

def f(row_id): 
# define here your function inside the loop
    return vstack(res_rows, 'csr')

if __name__ == '__main__':
    p = Pool(4) # if you have 4 cores in your processor
    p.map(f, xrange(m1.shape[0]))

来源:https://docs.python.org/2/library/multiprocessing.html#using-a-pool-of-workers

请注意,某些python实现的函数已经使用了多处理(numpy中常见),因此在执行此解决方案之前,应该在脚本运行时检查处理器活动。