我有一个代码,我需要在矩阵之间进行大量的乘法操作。该代码旨在用于任意维度n的2D矩阵,其原则上可能非常大,使得程序非常慢。 到目前为止,为了操作乘法,我总是使用np.dot,如下例
def getV(csi, e, e2, k):
ktrans = k.transpose()
v = np.dot(csi, ktrans)
v = np.dot(v, e)
v = np.dot(v, k)
v = np.dot(v, csi)
v = np.dot(v, ktrans)
e2trans = e2.transpose()
v = np.dot(v, e2trans)
v = np.dot(v, k)
traceV = 2*v.trace()
return traceV
其中输出应该是产品痕迹的两倍:
csi*ktrans*e*k*csi*ktrans*e2trans*k
(它们都是矩阵相乘)。 我相信有一种更快的方法可以制作这么长的产品,可能在一个段落中。有人可以解释一下吗?我试过了,但似乎np.dot在任何一个段落中总是只需要两个矩阵。
答案 0 :(得分:3)
由于properties of the trace这个计算可以重写如下,这将矩阵乘法的数量从7减少到4:
def getV(csi, k, e, e2):
temp = k.dot(csi).dot(k.T)
trace_ = (temp.dot(e).dot(temp) * e2).sum()
return 2 * trace_
根据您当前的设置,您还可以尝试安装不同的BLAS库或在显卡而不是CPU上进行计算。