跨多个轴的点积

时间:2017-07-28 11:16:27

标签: python arrays numpy dot-product

给定两个numpy数组,其中第一个d维度的大小相等

import numpy

d = 3
a = numpy.random.rand(2, 2, 2, 12, 3)
b = numpy.random.rand(2, 2, 2, 5)

我想计算这些第一维的点积。此

a2 = a.reshape(-1, *a.shape[d:])
b2 = b.reshape(-1, *b.shape[d:])
out = numpy.dot(numpy.moveaxis(a2, 0, -1), numpy.moveaxis(b2, 0, -2))

有效,但仅当b不具有(2, 2, 2)形状时才有效。与reshapemoveaxis混在一起似乎也比必要更复杂。

有更优雅的解决方案吗? (也许是tensordot?)

2 个答案:

答案 0 :(得分:2)

再次使用np.einsum

np.einsum('ijklm,ijkn->lmn',a,b)

答案 1 :(得分:1)

事实证明,tensordot 毕竟是有用的。此

numpy.tensordot(a, b, axes=(range(3), range(3)))

诀窍。