如何在PyTorch中粘贴块张量列表以形成更大的张量

时间:2019-02-19 22:38:58

标签: pytorch torch

假定存在一个小张量(例如16个块)列表,并且希望沿水平和垂直方向粘贴这些小张量以创建较大的2D图像。

torch.split()可用于将张量分割为较小的块,是否有相反情况的操作?

谢谢

2 个答案:

答案 0 :(得分:0)

您正在使用torch.cat寻找dim。要垂直堆叠张量,请使用torch.cat(..., dim=0),要水平堆叠张量,请使用torch.cat(..., dim=1)

示例

tensors = torch.split(torch.randn(4, 6), 2, dim=1)

tensors
(tensor([[-1.0257,  0.5213],
         [-0.1181, -1.4420],
         [-1.5563, -1.0757],
         [ 1.1788,  0.6222]]), tensor([[-0.4531, -0.1260],
         [-0.2383, -1.3542],
         [-0.8752, -0.4728],
         [ 0.7879,  1.3686]]), tensor([[ 2.3357, -0.6220],
         [ 0.2687,  0.1146],
         [ 0.9912, -0.0586],
         [-0.8507,  0.5126]]))

沿第一维垂直堆叠:

torch.cat(tensors, dim=0)
tensor([[-1.0257,  0.5213],
        [-0.1181, -1.4420],
        [-1.5563, -1.0757],
        [ 1.1788,  0.6222],
        [-0.4531, -0.1260],
        [-0.2383, -1.3542],
        [-0.8752, -0.4728],
        [ 0.7879,  1.3686],
        [ 2.3357, -0.6220],
        [ 0.2687,  0.1146],
        [ 0.9912, -0.0586],
        [-0.8507,  0.5126]])

沿第二维水平堆叠:

torch.cat(tensors, dim=1)
tensor([[-1.0257,  0.5213, -0.4531, -0.1260,  2.3357, -0.6220],
        [-0.1181, -1.4420, -0.2383, -1.3542,  0.2687,  0.1146],
        [-1.5563, -1.0757, -0.8752, -0.4728,  0.9912, -0.0586],
        [ 1.1788,  0.6222,  0.7879,  1.3686, -0.8507,  0.5126]])

答案 1 :(得分:0)

只需对矩阵运算稍作改动即可完成构建块矩阵。假设您要堆叠到4个矩阵 A B C D

  

A B

     

C D

a = torch.tensor([[1,2],[3,4]])
tensor([[1, 2],
        [3, 4]])
b = torch.tensor([[5,6],[7,8]])
tensor([[5, 6],
        [7, 8]])
c = torch.tensor([[-1,-2],[-3,-4]])
tensor([[-1, -2],
        [-3, -4]])
d = torch.tensor([[-5,-6],[-7,-8]])
tensor([[-5, -6],
        [-7, -8]])

然后连接,转置并分成两个块。 (需要转置,因为我们稍后会转置,而这会取消)。

x,y = torch.cat((a,b,c,d),dim=1).t().chunk(2)
(tensor([[1, 3],
         [2, 4],
         [5, 7],
         [6, 8]]), 
 tensor([[-1, -3],
         [-2, -4],
         [-5, -7],
         [-6, -8]]))

接下来将这两个矩阵并排放置并转置

torch.cat((x,y),dim=1).t()
tensor([[ 1,  2,  5,  6],
        [ 3,  4,  7,  8],
        [-1, -2, -5, -6],
        [-3, -4, -7, -8]])

概括为 N x N 块矩阵应该很简单。