I am training a simple LSTM model, but PyTorch gives me an error telling me I need to set retain_graph=True. However, that makes training take much longer, and I don't think I should need it.
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

class SequenceModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(input_size=3, hidden_size=3, bidirectional=False)
        # initial hidden/cell state: (num_layers, batch, hidden_size)
        self.hidden = (torch.randn(1, 1, 3).double(), torch.randn(1, 1, 3).double())

    def forward(self, x):
        # carry the hidden state over between calls
        lstm_out, self.hidden = self.lstm(x.view(-1, 1, 3), self.hidden)
        return lstm_out

    def loss(self, logits, labels):
        return F.cross_entropy(logits, labels)

model = SequenceModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
model = model.double()
model.train()

epochs = 1000
for epoch in tqdm(range(epochs)):
    optimizer.zero_grad()
    logits = model(inputs)  # inputs/outputs are my training tensors, defined elsewhere
    logits = logits.reshape(-1, 3)
    loss = model.loss(logits, outputs.long())
    loss.backward()
    optimizer.step()
The error I get is:
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
But I don't want to set retain_graph to True.
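My guess is that the problem comes from self.hidden being carried over between epochs, so each backward() call reaches back into the graph built in earlier iterations. If that is the cause, would detaching the hidden state at the start of each epoch, as in this rough, untested sketch, be the right way to avoid retain_graph=True?

# untested sketch: detach the recurrent state so each epoch builds a fresh graph
for epoch in tqdm(range(epochs)):
    optimizer.zero_grad()
    # drop references to the previous iteration's graph
    model.hidden = tuple(h.detach() for h in model.hidden)
    logits = model(inputs).reshape(-1, 3)
    loss = model.loss(logits, outputs.long())
    loss.backward()
    optimizer.step()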