numpy.einsum表达式的含义/等效项

时间:2019-09-16 20:43:39

标签: python numpy numpy-einsum

我正拼命寻找与以下numpy.einsum表达式等效的python内置函数:

>>> a = np.array((((1, 2), (3, 4)), ((5, 6), (7, 8))))
>>> a
array([[[1, 2],
        [3, 4]],

       [[5, 6],
        [7, 8]]])

>>> b = np.array((((9, 10), (11, 12)), ((13, 14), (15, 16))))
>>> b
array([[[ 9, 10],
        [11, 12]],

       [[13, 14],
        [15, 16]]])

>>> np.einsum("abc,abd->dc", a, b)
array([[212, 260],
       [228, 280]])

1 个答案:

答案 0 :(得分:0)

@AlexRiley评论中的直接翻译是这样的:

(a[...,None,:]*b[...,None]).sum((0,1))

让我们解析规范字符串“ abc,abd-> dc”,然后将术语重命名为x和y,以免它们与索引冲突:

这被读取为结果 dc = ∑ ab x abc y abd

如您所见,索引是从规范字符串中逐字获取的。将结果规范中未出现的指标相加。就是这样。

侧面说明:我们可以做得更好:合并前两个轴,可以将表达式读为numpy使用高度优化的代码路径的矩阵乘积:

b.reshape(-1,b.shape[-1]).T@a.reshape(-1,a.shape[-1])

这是直接翻译的两倍多,也比原始einsum快一点。