假设A
,B
和C
是任何格式的稀疏矩阵。而且我想知道如何有效地计算他们的点积仅适用于C
中非零的元素。
等于
prod = A.dot(B)
prod[C == 0] = 0
用于Python中的密集矩阵。但是这段代码效率极低。
你可以告诉我什么吗?记忆并不重要。答案 0 :(得分:0)
您可以将C
强制转换为布尔值,并利用值True
在乘法上下文中变为1
这一事实。然后,您可以在A.dot(B)
和C
的产品之间进行元素相乘。
您可以通过以下方式实现这一目标:
A.dot(B).to_csr().multiple(C.to_csr())
为了获得最快的点积,我会粗略搜索scipy
提供的所有稀疏格式。定义计时功能,例如:
from functools import wraps
from time import time
def timing(f):
@wraps(f)
def wrapper(*args, **kwargs):
start = time()
result = f(*args, **kwargs)
end = time()
print 'Elapsed time: {}'.format(end-start)
return result
return wrapper
@timing
def csr_dot(a, b):
# Write similar functions for all other formats
return a.to_csr().dot(b.to_csr())
# This will print some time. Repeat for other formats.
csr_dot(A, B)
然后,您可以选择产生最佳时间的格式。