如何在不使用优化器的情况下将渐变设置为零?

时间:2019-02-12 10:33:48

标签: gradient pytorch

在互斥体.backward()之间,我想将梯度设置为零。现在,我必须分别对每个组件执行此操作(这里是xt),有没有办法针对所有受影响的变量“全局”执行此操作? (我想像是z.set_all_gradients_to_zero()之类的东西。)

我知道如果您使用优化器,就有optimizer.zero_grad(),但是还有一种不使用优化器的直接方法吗?

import torch

x = torch.randn(3, requires_grad = True)
t = torch.randn(3, requires_grad = True)
y = x + t
z = y + y.flip(0)

z.backward(torch.tensor([1., 0., 0.]), retain_graph = True)
print(x.grad)
print(t.grad)
x.grad.data.zero_()  # both gradients need to be set to zero 
t.grad.data.zero_()
z.backward(torch.tensor([0., 1., 0.]), retain_graph = True)
print(x.grad)
print(t.grad)

1 个答案:

答案 0 :(得分:1)

您也可以使用nn.Module.zero_grad()。实际上,optim.zero_grad()只是对传递给它的所有参数调用nn.Module.zero_grad()

没有合理的方法在全球范围内进行。您可以将变量收集在列表中

grad_vars = [x, t]
for var in grad_vars:
    var.grad.data = None

或基于vars()创建一些黑客功能。也许也可以检查计算图并将所有叶节点的梯度归零,但是我对图API并不熟悉。简而言之,您应该使用torch.nn的面向对象的接口,而不是手动创建张量变量。