为什么PyTorch需要保留图形?

时间:2018-09-28 17:17:17

标签: python pytorch

我像这样训练我的模型:

for i in range(5):
  optimizer.zero_grad()
  y = next_input()
  loss = model(y)
  loss.backward()
  optimizer.step()

并得到此错误

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

为什么要求我保留图表?如果释放它们,它可能只是重新计算衍生物。为了证明这一点,请考虑以下代码:

for i in range(5):
  optimizer.zero_grad()
  model.zero_grad() # drop derivatives
  y = next_input()
  loss = model(y)
  loss.backward(retain_graph=True)
  optimizer.step()

在这种情况下,上一次迭代的导数也被清零,但是Torch不在乎,因为设置了标志retain_graph=True

我对model.zero_grad()抵消了retain_graph=True的影响(即丢弃保留的导数)吗?

2 个答案:

答案 0 :(得分:0)

您需要在每次通过之后将梯度归零,以告诉割炬抛出先前累积的值,否则,每次调用loss.backward()时,它将在整个计算中向后传播。

因此代码应为

for i in range(5):
  y = next_input()
  loss = model(y)
  loss.backward()
  optimizer.step()
  optimizer.zero_grad()

由于您不将梯度归零,因此pytorch尝试通过先前的计算进行反向传播,这就是为什么它会给您带来关于保留图形的错误的原因。如果保留图形,则基本上不会丢弃先前步骤的累积梯度。

This discussion on pytorch forum可能会对您有所帮助。它强调了我上面提到的同一件事

答案 1 :(得分:0)

由于所讨论的渐变是模型的渐变,因此正确的代码应为model.zero_grad()。我不确定optimizer.zero_grad()是否可以工作,因为我从未尝试过。您的第一个示例是:

for i in range(5):
  model.zero_grad()  # instead of optimizer.zero_grad()
  x, y = next_input_output_pair()  # We get both input and expected output
  loss = mean_squared_error(model(x), y)  # the loss is calculated
  loss.backward()  # backward calculation
  optimizer.step()