Pytorch中的截断时间反向传播(BPTT)

时间:2018-12-24 11:39:30

标签: pytorch backpropagation truncated

在pytorch中,我通过使用以下命令开始反向传播(穿越时间)来训练RNN / GRU / LSTM网络:

loss.backward()

当序列很长时,我想做一个截断时间反向传播,而不是使用整个序列的正常时间反向传播。

但是我在Pytorch API中找不到用于设置截断的BPTT的任何参数或函数。我想念吗?我应该在Pytorch中自己编写代码吗?

1 个答案:

答案 0 :(得分:0)

这里是一个例子:

for t in range(T):
   y = lstm(y)
   if T-t == k:
      out.detach()
out.backward()

因此,在此示例中,k是用于控制要展开的时间步的参数。