我有一个大小为X
的pytorch张量m x n
和一个长度为num_repeats
的非负整数n
列表(假定sum(num_repeats)> 0)。在forward()方法中,我想创建一个张量X_dup
,其大小为m x sum(num_repeats)
,其中i
的列X
被重复num_repeats[i]
次。张量X_dup
将在forward()方法的下游使用,因此需要正确地反向传播梯度。
我能想到的所有解决方案都需要就地操作(创建新的张量并通过遍历num_repeats
来填充它),但是如果我理解正确,将无法保留梯度(如果我做错了,请纠正我,我是整个Pytorch的新手。)
答案 0 :(得分:1)
假设您使用的是PyTorch> = 1.1.0,则可以使用torch.repeat_interleave
。
repeat_tensor = torch.tensor(num_repeats).to(X.device, torch.int64)
X_dup = torch.repeat_interleave(X, repeat_tensor, dim=1)