pyTorch可以向后两次,而无需设置keep_graph = True

时间:2018-09-23 06:02:16

标签: pytorch autograd

pyTorch tutorial所示,

  

如果您甚至想对图形的某些部分进行两次反向操作,   您需要在第一次通过时传递keep_graph = True。

但是,我发现以下代码片段实际上没有起作用。我正在使用pyTorch-0.4

x = torch.ones(2, 2, requires_grad=True)
y = x + 2
y.backward(torch.ones(2, 2)) # Note I do not set retain_graph=True
y.backward(torch.ones(2, 2)) # But it can still work!
print x.grad

输出:

tensor([[ 2.,  2.], 
        [ 2.,  2.]]) 

有人可以解释吗?预先感谢!

1 个答案:

答案 0 :(得分:4)

在您的情况下,它不使用retain_graph=True的原因是您有一个非常简单的图形,该图形可能没有内部中间缓冲区,进而没有缓冲区会被释放,因此无需使用retain_graph=True

但是,在图形中再添加一个额外的计算时,一切都会改变:

代码:

x = torch.ones(2, 2, requires_grad=True)
v = x.pow(3)
y = v + 2

y.backward(torch.ones(2, 2))

print('Backward 1st time w/o retain')
print('x.grad:', x.grad)

print('Backward 2nd time w/o retain')

try:
    y.backward(torch.ones(2, 2))
except RuntimeError as err:
    print(err)

print('x.grad:', x.grad)

输出:

Backward 1st time w/o retain
x.grad: tensor([[3., 3.],
                [3., 3.]])
Backward 2nd time w/o retain
Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
x.grad: tensor([[3., 3.],
                [3., 3.]]).

在这种情况下,将计算其他内部v.grad,但是torch不存储中间值(中间渐变等),并且使用retain_graph=False v.grad将被释放首先backward之后。

因此,如果您想第二次反向传播,则需要指定retain_graph=True来“保持”图形。

代码:

x = torch.ones(2, 2, requires_grad=True)
v = x.pow(3)
y = v + 2

y.backward(torch.ones(2, 2), retain_graph=True)

print('Backward 1st time w/ retain')
print('x.grad:', x.grad)

print('Backward 2nd time w/ retain')

try:
    y.backward(torch.ones(2, 2))
except RuntimeError as err:
    print(err)
print('x.grad:', x.grad)

输出:

Backward 1st time w/ retain
x.grad: tensor([[3., 3.],
                [3., 3.]])
Backward 2nd time w/ retain
x.grad: tensor([[6., 6.],
                [6., 6.]])