numpy einsum:嵌套点产品

时间:2016-07-16 17:45:25

标签: python numpy numpy-einsum

我有两个n - by - k - by - 3数组ab,例如,

import numpy as np

a = np.array([
    [
        [1, 2, 3],
        [3, 4, 5]
    ],
    [
        [4, 2, 4],
        [1, 4, 5]
    ]
    ])
b = np.array([
    [
        [3, 1, 5],
        [0, 2, 3]
    ],
    [
        [2, 4, 5],
        [1, 2, 4]
    ]
    ])

并且它想计算所有“三元组”对的点积,即

np.sum(a*b, axis=2)

更好的方法可能是einsum,但我似乎无法直接得到指数。

这里有任何提示吗?

1 个答案:

答案 0 :(得分:3)

您正在丢失这两个3D输入数组上的第三个轴并减少该总和,同时保持前两个轴对齐。因此,对于np.einsum,我们将使前两个字符串相同,第三个字符串也是相同的,但是在输出字符串表示法信号中将被跳过,我们正在沿着该轴减少两个输入。因此,解决方案是 -

np.einsum('ijk,ijk->ij',a,b)