在多维ndarray的最后一维上计算点积的最快方法是什么?
目前我正在这样做:
import numpy as np
a=np.reshape(np.arange(90),[3,3,2,5])
b=np.reshape(np.arange(90),[3,3,2,5])
# for the sake of simplicity, a and b are the same for this example
ab=(a*b).sum(axis=-1)
我认为einsum
在这里可能有用,但是我很难将其应用于我的情况。
谢谢!
答案 0 :(得分:1)
让通用ndim数组沿最后一个轴求和-
np.einsum('...i,...i->...',a,b)
替代np.matmul
-
np.matmul(a[...,None,:],b[...,None])[...,0,0]
注意:在Python 3.x上,np.matmul
可以替换为@ operator
。