如何在pytorch中返回中间梯度(对于非叶子节点)?

时间:2019-03-22 17:51:06

标签: python gradient pytorch register-hook

我的问题是关于pytorch register_hook的语法。

x = torch.tensor([1.], requires_grad=True)
y = x**2
z = 2*y

x.register_hook(print)
y.register_hook(print)

z.backward()

输出:

tensor([2.])
tensor([4.])

此代码段仅分别打印zx的渐变。

现在,我的(最可能是琐碎的)问题是如何返回中间渐变(而不是仅打印)?

更新:

看来,调用y解决了叶子节点的问题。例如retain_grad()

但是,y.retain_grad()似乎无法解决非叶子节点的问题。有什么建议吗?

1 个答案:

答案 0 :(得分:0)

我认为您可以使用这些挂钩将梯度存储在全局变量中:

grads = []
x = torch.tensor([1.], requires_grad=True)
y = x**2 + 1
z = 2*y

x.register_hook(lambda d:grads.append(d))
y.register_hook(lambda d:grads.append(d))

z.backward()

但是您很可能还需要记住计算这些梯度所对应的张量。在这种情况下,我们使用dict而不是list稍微扩展一下:

grads = {}
x = torch.tensor([1.,2.], requires_grad=True)
y = x**2 + 1
z = 2*y

def store(grad,parent):
    print(grad,parent)
    grads[parent] = grad.clone()

x.register_hook(lambda grad:store(grad,x))
y.register_hook(lambda grad:store(grad,y))

z.sum().backward()

例如,现在您可以使用y来访问张量grads[y]的等级