计算PyTorch中间节点的梯度

时间:2018-01-01 16:43:41

标签: pytorch

我试图了解在PyTorch中autograd是如何工作的。在下面的简单程序中,我不明白为什么loss w.r.t W1W2的渐变为None据我从文档中了解,W1W2是易变的,因此无法计算渐变。 是吗?我的意思是,我怎么不能把损失w.r.t中间节点的衍生物?谁能解释一下我在这里缺少什么?

import torch
import torch.autograd as tau

W = tau.Variable(torch.FloatTensor([[0, 1]]), requires_grad=True)
a = tau.Variable(torch.FloatTensor([[2, 2]]), requires_grad=False)
b = tau.Variable(torch.FloatTensor([[3, 3]]), requires_grad=False)

W1 = W  + a * a
W2 = W1 - b * b * b
Z = W2 * W2

print 'W:', W
print 'W1:', W1
print 'W2:', W2
print 'Z:', Z

loss = torch.sum((Z - 3) * (Z - 3))
print 'loss:', loss

# free W gradient buffer in case you are running this cell more than 2 times
if W.grad is not None: W.grad.data.zero_()

loss.backward()
print 'W.grad:', W.grad

# all of them are None
print 'W1.grad:', W1.grad
print 'W2.grad:', W2.grad
print 'a.grad:', a.grad
print 'b.grad:', b.grad
print 'Z.grad:', Z.grad

1 个答案:

答案 0 :(得分:3)

如果需要,中间渐变会累积in a C++ buffer但是为了节省内存,默认情况下它们不会被保留(在python对象中公开)。 仅保留使用requires_grad=True设置的叶子变量的渐变(在您的示例中为W

保留中间渐变的一种方法是注册一个钩子。这项工作的一个钩子是retain_grad()see PR) 在您的示例中,如果您撰写W2.retain_grad()W2的中间渐变将在W2.grad

中公开

W1W2不易变(您可以通过访问他们的volatile属性(即:W1.volatile)进行检查)并且不能因为它们不是叶子变量(例如Wab)。相反,需要计算它们的渐变,请参见它们的requires_grad属性。 如果只有一个叶子变量为volatile,则不构造整个后向图(您可以通过制作一个易失性并查看损失梯度函数来检查)

a = tau.Variable(torch.FloatTensor([[2, 2]]), volatile=True)
# ...
assert loss.grad_fn is None

总结

  • 波动率表示无梯度计算:在推理模式下很有用
    • 只有一个叶变量设置为volatile 禁用渐变计算
  • 需要渐变意味着渐变计算。中间的是否暴露
    • 只有一个叶子变量需要渐变启用渐变计算
相关问题