我正在使用缓冲区来传递 LSTM 网络的隐藏状态。
def __init__(self, model, hidden_state1=None, ...somethine else...):
self.register_buffer('hidden_state1', hidden_state1)
self.hidden_state1 = hidden_state1
....#other codes
为了避免错误:
RuntimeError: Trying to backward through the graph a second time,
but the buffers have already been freed.
Specify retain_graph=True when calling backward the first time.
我使用 .clone().detach()
来分离缓冲区。
无论如何我都需要手动分离它们,我还需要在 Pytorch 中使用缓冲区而不是普通参数吗?
带有“requires_grad=False”的普通参数是否足以替代缓冲区的使用?
(其实我也不知道这样传递隐藏状态是不是好方法)