我有一个PyTorch张量元组列表。看起来像这样:
[
(tensor([1, 2, 3]), tensor([4, 5, 6, 7]), tensor([8])),
(tensor([9, 10,11]), tensor([11,12,13,14]), tensor([15])),
(tensor([16,17,18]), tensor([19,20,21,22]), tensor([23])),
...
]
每列中的张量(即,位于其各自元组的k处的张量)具有相同的形状。我想将张量堆叠在每列中,以便最终得到一个元组,每个值都是沿着列维连接的张量。
在这种情况下,输出元组将具有三个值,如下所示:
(
tensor([[1,2,3], [9,10,11], [16,17,18]]),
tensor([[4,5,6,7], [11,12,13,14], [19,20,21,22]],
tensor([[8],[15],[23])
)
这是一个虚构的示例。我想对任何长度的元组和任意大小的张量执行此操作。使用PyTorch快速进行这种串联的最佳方法是什么?
答案 0 :(得分:0)
如果有人陷入相同的混乱局面,我就能用一个可爱的单线解决它:
tuple(map(torch.tensor, zip(*x)))
在这种情况下,x
是我上面提到的原始列表。这行代码将x
转换为所需的确切格式。