np.dot用于2D矩阵之间的多个产品

时间:2015-03-19 15:17:00

标签: python performance numpy matrix multiplication

我有一个代码,我需要在矩阵之间进行大量的乘法操作。该代码旨在用于任意维度n的2D矩阵,其原则上可能非常大,使得程序非常慢。 到目前为止,为了操作乘法,我总是使用np.dot,如下例

def getV(csi, e, e2, k):
    ktrans = k.transpose()
    v = np.dot(csi, ktrans)
    v = np.dot(v, e)
    v = np.dot(v, k)
    v = np.dot(v, csi)
    v = np.dot(v,  ktrans)
    e2trans = e2.transpose()
    v = np.dot(v, e2trans)
    v = np.dot(v, k)
    traceV = 2*v.trace()
    return traceV

其中输出应该是产品痕迹的两倍:

csi*ktrans*e*k*csi*ktrans*e2trans*k

(它们都是矩阵相乘)。 我相信有一种更快的方法可以制作这么长的产品,可能在一个段落中。有人可以解释一下吗?我试过了,但似乎np.dot在任何一个段落中总是只需要两个矩阵。

1 个答案:

答案 0 :(得分:3)

由于properties of the trace这个计算可以重写如下,这将矩阵乘法的数量从7减少到4:

def getV(csi, k, e, e2):
    temp = k.dot(csi).dot(k.T)
    trace_ = (temp.dot(e).dot(temp) * e2).sum()
    return 2 * trace_

根据您当前的设置,您还可以尝试安装不同的BLAS库或在显卡而不是CPU上进行计算。