有没有更好的方法来沿第一维将两个 Pytorch 张量相乘和求和?

时间:2021-03-05 01:24:00

标签: python pytorch tensor

我有两个 Pytorch 张量,ab,形状分别为 (S, M)(S, M, H)M 是我的批次维度。我想将两个张量相乘和求和,使输出的形状为 (M, H)。也就是说,我想计算 sa[s] * b[s] 上的总和。

例如,对于 S=2M=2H=3

>>> import torch
>>> S, M, H = 2, 2, 3
>>> a = torch.arange(S*M).view((S,M))
tensor([[0, 1],
        [2, 3]])
>>> b = torch.arange(S*M*H).view((S,M,H))
tensor([[[ 0,  1,  2],
         [ 3,  4,  5]],

        [[ 6,  7,  8],
         [ 9, 10, 11]]])

'''
DESIRED OUTPUT:
= [[0*[0, 1, 2] + 2*[6, 7, 8]], 
   [1*[3, 4, 5] + 3*[9, 10, 11]]]

= [[12, 14, 16],
   [30, 34, 38]]

note: shape is (2, 3) = (M, H)
'''

我找到了一种使用 torch.tensordot 的方法:

>>> output = torch.tensordot(a, b, ([0], [0]))
tensor([[[12, 14, 16],
         [18, 20, 22]],

        [[18, 22, 26],
         [30, 34, 38]]])
>>> output.shape
torch.Size([2, 2, 3]) # always (M, M, H)
>>> output = output[torch.arange(M), torch.arange(M), :]
tensor([[12, 14, 16],
        [30, 34, 38]])

但正如您所看到的,它进行了很多不必要的计算,我必须将与我相关的部分进行切片。

有没有更好的方法来做到这一点,而不涉及不必要的计算?

1 个答案:

答案 0 :(得分:2)

这应该有效:

(torch.unsqueeze(a, 2)*b).sum(axis=0)
>>> tensor([[12, 14, 16],
            [30, 34, 38]])