在no_grad()PyTorch中未禁用渐变计算

时间:2019-05-13 12:04:23

标签: pytorch

为什么在下面的代码中没有禁用y的梯度计算?

x = torch.randn(3, requires_grad=True)
print(x.requires_grad)
print((x ** 2).requires_grad)
y = x**2
print(y.requires_grad)
with torch.no_grad():
    print((x ** 2).requires_grad)
    print(y.requires_grad)

哪个给出以下输出:

True
True
True
False
True

2 个答案:

答案 0 :(得分:0)

仔细阅读官方文档后,即使输入的内容为require_grad=False,结果也将为required_grad=True

  

当您确定时,禁用梯度计算对于推断很有用       您不会调用:meth:Tensor.backward()。它将减少内存       否则将具有requires_grad=True的计算的消耗。       在这种模式下,每次计算的结果将具有       requires_grad=False,即使输入中有requires_grad=True

答案 1 :(得分:0)

我不知道torch.no_grad()的具体实现,但是文档中包含句子 每次计算的结果 ,这意味着它仅适用于结果,但不是原始变量。 运行以下代码:

with torch.no_grad():
    print(x.grad)

这将给出输出:

True

因此,y不是result上下文中出现的torch.no_grad()