我该如何解决向后()得到意外的关键字参数'retain_variables'吗?

时间:2019-04-07 23:52:30

标签: python python-3.x python-2.7 pytorch

我在下面编写了以下代码,但出现此错误:

TypeError: backward() got an unexpected keyword argument 'retain_variables'

我的代码是:

def learn(self, batch_state, batch_next_state, batch_reward, batch_action):
    outputs = self.model(batch_state).gather(1, batch_action.unsqueeze(1)).squeeze(1)
    next_outputs = self.model(batch_next_state).detach().max(1)[0]
    target = self.gamma*next_outputs + batch_reward
    td_loss = F.smooth_l1_loss(outputs, target)
    self.optimizer.zero_grad()
    td_loss.backward(retain_variables = True)
    self.optimizer.step()

2 个答案:

答案 0 :(得分:0)

a_guest mentions in the comments:

  

应为keep_graph = True。

答案 1 :(得分:0)

我遇到了同样的问题。这个解决方案对我有用。

td_loss.backward(retain_graph = True)

有效。