为什么在下面的代码中没有禁用y
的梯度计算?
x = torch.randn(3, requires_grad=True)
print(x.requires_grad)
print((x ** 2).requires_grad)
y = x**2
print(y.requires_grad)
with torch.no_grad():
print((x ** 2).requires_grad)
print(y.requires_grad)
哪个给出以下输出:
True
True
True
False
True
答案 0 :(得分:0)
仔细阅读官方文档后,即使输入的内容为require_grad=False
,结果也将为required_grad=True
当您确定时,禁用梯度计算对于推断很有用 您不会调用:meth:
Tensor.backward()
。它将减少内存 否则将具有requires_grad=True
的计算的消耗。 在这种模式下,每次计算的结果将具有requires_grad=False
,即使输入中有requires_grad=True
。
答案 1 :(得分:0)
我不知道torch.no_grad()
的具体实现,但是文档中包含句子 每次计算的结果 ,这意味着它仅适用于结果,但不是原始变量。
运行以下代码:
with torch.no_grad():
print(x.grad)
这将给出输出:
True
因此,y
不是result
上下文中出现的torch.no_grad()
。