在pytorch中,我通过使用以下命令开始反向传播(穿越时间)来训练RNN / GRU / LSTM网络:
loss.backward()
当序列很长时,我想做一个截断时间反向传播,而不是使用整个序列的正常时间反向传播。
但是我在Pytorch API中找不到用于设置截断的BPTT的任何参数或函数。我想念吗?我应该在Pytorch中自己编写代码吗?
答案 0 :(得分:0)
这里是一个例子:
for t in range(T):
y = lstm(y)
if T-t == k:
out.detach()
out.backward()
因此,在此示例中,k
是用于控制要展开的时间步的参数。