在Pytorch中连接两个具有不同尺寸的张量

时间:2018-05-19 10:16:52

标签: python pytorch

是否可以在不使用for循环的情况下连接两个具有不同尺寸的张量。

e.g。张量1具有尺寸(15,200,2048),张量2具有尺寸(1,200,2048)。是否有可能沿第一张量的第一维度的所有15个指数(沿张量1的第一维广播第二张量,同时沿第一张量的第三维连接)连接第二张量和第一张量?得到的张量应该具有尺寸(15,200,4096)。

是否可以在没有for循环的情况下完成此操作?

1 个答案:

答案 0 :(得分:5)

您可以在连接之前手动进行广播(使用Tensor.expand())(使用torch.cat()):

import torch

a = torch.randn(15, 200, 2048)
b = torch.randn(1, 200, 2048)

repeat_vals = repeat_vals = [a.shape[0] // b.shape[0]] + [-1] * (len(b.shape) - 1)
# or directly repeat_vals = (15, -1, -1) or (15, 200, 2048) if shapes are known and fixed...
res = torch.cat((a, b.expand(*repeat_vals)), dim=-1)
print(res.shape)
# torch.Size([15, 200, 4096])