PyTorch LSTM produces a retain_graph error - how do I fix it?

Posted: 2020-03-03 09:09:50

Tags: python pytorch

I am training a simple LSTM model, but PyTorch gives me an error telling me to set retain_graph = True. However, that makes training take much longer, and I don't think it should be necessary.

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

class SequenceModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=3, hidden_size=3, bidirectional=False)
        # Initial hidden and cell states, stored on the module and
        # reused by every forward call
        self.hidden = (torch.randn(1, 1, 3).double(),
                       torch.randn(1, 1, 3).double())

    def forward(self, x):
        # self.hidden carries over from the previous call, so each new
        # output stays connected to earlier iterations' graphs
        lstm_out, self.hidden = self.lstm(x.view(-1, 1, 3), self.hidden)
        return lstm_out

    def loss(self, logits, labels):
        return F.cross_entropy(logits, labels)

model = SequenceModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

model = model.double()

model.train()
epochs = 1000
for epoch in tqdm(range(epochs)):
    optimizer.zero_grad()

    # inputs and outputs are the training tensors, defined elsewhere
    logits = model(inputs)
    logits = logits.reshape(-1, 3)
    loss = model.loss(logits, outputs.long())

    loss.backward()
    optimizer.step()

The error I get is:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

But I don't want to set retain_graph to True.
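
For reference, this error typically means backward() is being run through a graph whose buffers were already freed. In the code above, self.hidden is kept on the module between forward calls, so every loss after the first is still connected to the previous iterations' graphs, which PyTorch freed after the first backward(). One common way to avoid retain_graph=True, assuming gradients are not meant to flow across training iterations, is to detach the stored hidden state at the start of each forward pass. A minimal sketch of that change:

    def forward(self, x):
        # detach() keeps the hidden state's values but drops its autograd
        # history, so backward() never revisits a freed graph
        self.hidden = tuple(h.detach() for h in self.hidden)
        lstm_out, self.hidden = self.lstm(x.view(-1, 1, 3), self.hidden)
        return lstm_out

Re-initializing self.hidden at the start of each iteration would cut the graph in the same way.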

0 Answers:

No answers yet