我的问题是关于pytorch register_hook
的语法。
x = torch.tensor([1.], requires_grad=True)
y = x**2
z = 2*y
x.register_hook(print)
y.register_hook(print)
z.backward()
输出:
tensor([2.])
tensor([4.])
此代码段仅分别打印z
和x
的渐变。
现在,我的(最可能是琐碎的)问题是如何返回中间渐变(而不是仅打印)?
更新:
看来,调用y
解决了叶子节点的问题。例如retain_grad()
。
但是,y.retain_grad()
似乎无法解决非叶子节点的问题。有什么建议吗?
答案 0 :(得分:0)
我认为您可以使用这些挂钩将梯度存储在全局变量中:
grads = []
x = torch.tensor([1.], requires_grad=True)
y = x**2 + 1
z = 2*y
x.register_hook(lambda d:grads.append(d))
y.register_hook(lambda d:grads.append(d))
z.backward()
但是您很可能还需要记住计算这些梯度所对应的张量。在这种情况下,我们使用dict
而不是list
稍微扩展一下:
grads = {}
x = torch.tensor([1.,2.], requires_grad=True)
y = x**2 + 1
z = 2*y
def store(grad,parent):
print(grad,parent)
grads[parent] = grad.clone()
x.register_hook(lambda grad:store(grad,x))
y.register_hook(lambda grad:store(grad,y))
z.sum().backward()
例如,现在您可以使用y
来访问张量grads[y]
的等级