我有一个k x (n+k-1)
的pytorch张量w
张量requires_grad=True
。我也想将它转换为kxn
张量p
。我该怎么做,以至于最后通过对p[i] = w[i][i:i+n]
的损失函数调用backward()
,我将学会p
?
答案 0 :(得分:1)
任何类型的索引操作都可以执行,后向功能为<CopySlices>
一种简单的方法是使用简单的python索引:
w_unrolled = torch.zeros(p.size())
for i in range(w.shape[0]):
w_unrolled[i] = w[i][i:i+n]
loss = criterion(w_unrolled, p)
然后,您可以通过任意轴上的均值/和来减少损失。请注意,尽管这将起作用,但效率不高;最佳方法是使用本机索引功能来加快处理速度。