有效地计算数组最后一维的点积

时间:2019-10-21 10:43:00

标签: python-3.x numpy numpy-einsum

在多维ndarray的最后一维上计算点积的最快方法是什么?

目前我正在这样做:

import numpy as np

a=np.reshape(np.arange(90),[3,3,2,5])
b=np.reshape(np.arange(90),[3,3,2,5])
# for the sake of simplicity, a and b are the same for this example

ab=(a*b).sum(axis=-1)

我认为einsum在这里可能有用,但是我很难将其应用于我的情况。

谢谢!

1 个答案:

答案 0 :(得分:1)

让通用ndim数组沿最后一个轴求和-

np.einsum('...i,...i->...',a,b)

替代np.matmul-

np.matmul(a[...,None,:],b[...,None])[...,0,0]

注意:在Python 3.x上,np.matmul可以替换为@ operator