我有两个 Pytorch 张量,a
和 b
,形状分别为 (S, M)
和 (S, M, H)
。 M
是我的批次维度。我想将两个张量相乘和求和,使输出的形状为 (M, H)
。也就是说,我想计算 s
的 a[s] * b[s]
上的总和。
例如,对于 S=2
、M=2
、H=3
:
>>> import torch
>>> S, M, H = 2, 2, 3
>>> a = torch.arange(S*M).view((S,M))
tensor([[0, 1],
[2, 3]])
>>> b = torch.arange(S*M*H).view((S,M,H))
tensor([[[ 0, 1, 2],
[ 3, 4, 5]],
[[ 6, 7, 8],
[ 9, 10, 11]]])
'''
DESIRED OUTPUT:
= [[0*[0, 1, 2] + 2*[6, 7, 8]],
[1*[3, 4, 5] + 3*[9, 10, 11]]]
= [[12, 14, 16],
[30, 34, 38]]
note: shape is (2, 3) = (M, H)
'''
我找到了一种使用 torch.tensordot
的方法:
>>> output = torch.tensordot(a, b, ([0], [0]))
tensor([[[12, 14, 16],
[18, 20, 22]],
[[18, 22, 26],
[30, 34, 38]]])
>>> output.shape
torch.Size([2, 2, 3]) # always (M, M, H)
>>> output = output[torch.arange(M), torch.arange(M), :]
tensor([[12, 14, 16],
[30, 34, 38]])
但正如您所看到的,它进行了很多不必要的计算,我必须将与我相关的部分进行切片。
有没有更好的方法来做到这一点,而不涉及不必要的计算?
答案 0 :(得分:2)
这应该有效:
(torch.unsqueeze(a, 2)*b).sum(axis=0)
>>> tensor([[12, 14, 16],
[30, 34, 38]])