如何在张量中复制输入通道?

时间:2020-02-04 13:41:03

标签: pytorch tensor

我有一个torch.Size([39, 1, 20, 256, 256])形状的张量,如何复制通道以制成torch.Size([39, 3, 20, 256, 256])形状。

1 个答案:

答案 0 :(得分:4)

我相当确定这已经是一个重复的问题,但是我自己找不到合适的答案,这就是为什么我继续通过同时引用PyTorch documentation和{{3}来回答这个问题的原因}

本质上,torch.Tensor.expand()是您要寻找的功能,可以按以下方式使用:

x = torch.rand([39, 1, 20, 256, 256])
y = x.expand(39, 3, 20, 256, 256)

请注意,这仅适用于单个尺寸,在您的示例中就是这种情况,但可能不适用于扩展前的任意尺寸。另外,这基本上只是提供了一个不同的内存视图,这意味着根据文档,您必须牢记以下几点:

扩展张量中的多个元素可以引用单个 内存位置。结果,就地操作(尤其是那些 向量化)可能会导致错误的行为。如果你需要 写张量,请先克隆它们。

有关新分配的内存版本,请参见PyTorch forum中概述的torch.Tensor.repeat。语法的其他方面与expand()完全相同。