变换损失函数的Pytorch参数矩阵

时间:2018-12-08 13:38:53

标签: pytorch

我有一个k x (n+k-1)的pytorch张量w张量requires_grad=True。我也想将它转换为kxn张量p。我该怎么做,以至于最后通过对p[i] = w[i][i:i+n]的损失函数调用backward(),我将学会p

1 个答案:

答案 0 :(得分:1)

任何类型的索引操作都可以执行,后向功能为<CopySlices> 一种简单的方法是使用简单的python索引:

w_unrolled = torch.zeros(p.size())
for i in range(w.shape[0]):
    w_unrolled[i] = w[i][i:i+n]
loss = criterion(w_unrolled, p)

然后,您可以通过任意轴上的均值/和来减少损失。请注意,尽管这将起作用,但效率不高;最佳方法是使用本机索引功能来加快处理速度。