用火炬获得形状(4,1,84,84)

时间:2020-04-05 15:04:59

标签: python pytorch

假设我有四个pytorch张量(tensor1, tensor2, tensor3, tensor4)。每个张量的形状为(1, 1, 84, 84)。第一个维度是张量的数量,第二个维度是颜色的数量(例如在我们的示例中为灰度),最后两个维度代表图像的高度和宽度。

我想堆叠它们,以得到形状(4, 1, 84, 84)

我尝试过torch.stack((tensor1, tensor2, tensor3, tensor4), dim=0),但是形状却是(4, 1, 1, 84, 84)

如何堆叠这些张量,以便形状为(4, 1, 84, 84)

1 个答案:

答案 0 :(得分:5)

您可以使用连接功能:

a = torch.ones(1,1,84,84)
b = torch.ones(1,1,84,84)
c = torch.cat((a,b), 0) # size[2,1,84,84]