当输入是许多相同的数组时,使np.einsum更快吗? (或其他更快的方法)

时间:2020-07-09 03:59:43

标签: python numpy numpy-einsum

我有一段代码:

nnt = np.real(np.einsum('xa,xb,yc,yd,abcde->exy',evec,evec,evec,evec,quartic))

其中evec是(例如)L x L np.float32数组,而quartic是L xL x L x L x T np.complex64数组。

我发现这个例程很慢。

我认为既然所有evec都相同,那么可能会有更快的方法吗?

谢谢。

1 个答案:

答案 0 :(得分:3)

首先,您可以重复使用第一个计算:

evec2 = np.real(np.einsum('xa,xb->xab',evec,evec))
nnt = np.real(np.einsum('xab,ycd,abcde->exy',evec2,evec2,quartic))

如果您不关心内存,只需要性能:

evec2 = np.real(np.einsum('xa,xb->xab',evec,evec))
nnt = np.real(np.einsum('xab,ycd,abcde->exy',evec2,evec2,quartic,optimize=True))