两个张量的Pytorch广播产品

时间:2018-06-22 07:08:12

标签: python deep-learning matrix-multiplication pytorch tensor

我想将两个张量相乘,这就是我得到的:

  • A形状的张量(20, 96, 110)
  • B形状的张量(20, 16, 110)

第一个索引用于批次大小。 我想要做的基本上是从B-(20, 1, 110)中获取每个张量,然后,我想将每个A张量(20, n, 110)乘以。 因此乘积将在末尾:张量AB,其形状为(20, 96 * 16, 110)

所以我想通过A广播来将B中的每个张量相乘。 PyTorch中有可以做到这一点的方法吗?

1 个答案:

答案 0 :(得分:1)

使用torch.einsum,然后使用torch.reshape

AB = torch.einsum("ijk,ilk->ijlk", (A, B)).reshape(A.shape[0], -1, A.shape[2])

示例:

import numpy as np
import torch

# A of shape (2, 3, 2):
A = torch.from_numpy(np.array([[[1, 1], [2, 2], [3, 3]], 
                               [[4, 4], [5, 5], [6, 6]]]))
# B of shape (2, 2, 2):
B = torch.from_numpy(np.array([[[1, 1], [10, 10]], 
                               [[2, 2], [20, 20]]]))

# AB of shape (2, 3*2, 2):
AB = torch.einsum("ijk,ilk->ijlk", (A, B)).reshape(A.shape[0], -1, A.shape[2])
# tensor([[[ 1, 1], [ 10, 10], [  2,  2], [ 20,   20], [ 3,   3], [ 30,  30]],
#         [[ 8, 8], [ 80, 80], [ 10, 10], [ 100, 100], [ 12, 12], [ 120, 120]]])