在张量元组列表中堆叠张量

时间:2019-12-03 02:13:45

标签: python pytorch torch

我有一个PyTorch张量元组列表。看起来像这样:

[
    (tensor([1, 2, 3]),  tensor([4, 5, 6, 7]),  tensor([8])),
    (tensor([9, 10,11]), tensor([11,12,13,14]), tensor([15])),
    (tensor([16,17,18]), tensor([19,20,21,22]), tensor([23])),
    ...
]

每列中的张量(即,位于其各自元组的k处的张量)具有相同的形状。我想将张量堆叠在每列中,以便最终得到一个元组,每个值都是沿着列维连接的张量。

在这种情况下,输出元组将具有三个值,如下所示:

(
 tensor([[1,2,3], [9,10,11], [16,17,18]]),

 tensor([[4,5,6,7], [11,12,13,14], [19,20,21,22]],

 tensor([[8],[15],[23])
)

这是一个虚构的示例。我想对任何长度的元组和任意大小的张量执行此操作。使用PyTorch快速进行这种串联的最佳方法是什么?

1 个答案:

答案 0 :(得分:0)

如果有人陷入相同的混乱局面,我就能用一个可爱的单线解决它:

tuple(map(torch.tensor, zip(*x)))

在这种情况下,x是我上面提到的原始列表。这行代码将x转换为所需的确切格式。