在Pytorch中重复张量的特定列

时间:2019-12-07 14:28:32

标签: python pytorch

我有一个大小为X的pytorch张量m x n和一个长度为num_repeats的非负整数n列表(假定sum(num_repeats)> 0)。在forward()方法中,我想创建一个张量X_dup,其大小为m x sum(num_repeats),其中i的列X被重复num_repeats[i]次。张量X_dup将在forward()方法的下游使用,因此需要正确地反向传播梯度。 我能想到的所有解决方案都需要就地操作(创建新的张量并通过遍历num_repeats来填充它),但是如果我理解正确,将无法保留梯度(如果我做错了,请纠正我,我是整个Pytorch的新手。)

1 个答案:

答案 0 :(得分:1)

假设您使用的是PyTorch> = 1.1.0,则可以使用torch.repeat_interleave

repeat_tensor = torch.tensor(num_repeats).to(X.device, torch.int64)
X_dup = torch.repeat_interleave(X, repeat_tensor, dim=1)