pytorch如何设置.requires_grad False

时间:2018-08-08 13:36:47

标签: python pytorch gradient-descent

我想冻结一些模型。遵循官方文档:

with torch.no_grad():
    linear = nn.Linear(1, 1)
    linear.eval()
    print(linear.weight.requires_grad)

但是它打印True而不是False。如果我想将模型设置为评估模式,该怎么办?

5 个答案:

答案 0 :(得分:12)

requires_grad = False

如果要冻结部分模型并训练其余模型,可以将要冻结的参数requires_grad设置为False

例如,如果您只想固定VGG16的卷积部分:

model = torchvision.models.vgg16(pretrained=True)
for param in model.features.parameters():
    param.requires_grad = False

通过将requires_grad标志切换为False,将不会保存任何中间缓冲区,直到计算达到某个操作的输入之一需要渐变的点为止。

torch.no_grad()

使用上下文管理器torch.no_grad是实现该目标的另一种方式:在no_grad上下文中,所有计算结果都将具有requires_grad=False,即使输入具有{ {1}}。请注意,您将无法将渐变反向传播到requires_grad=True之前的图层。例如:

no_grad

输出:

x = torch.randn(2, 2)
x.requires_grad = True

lin0 = nn.Linear(2, 2)
lin1 = nn.Linear(2, 2)
lin2 = nn.Linear(2, 2)
x1 = lin0(x)
with torch.no_grad():    
    x2 = lin1(x1)
x3 = lin2(x2)
x3.sum().backward()
print(lin0.weight.grad, lin1.weight.grad, lin2.weight.grad)

此处(None, None, tensor([[-1.4481, -1.1789], [-1.4481, -1.1789]])) 为True,但未计算梯度,因为操作是在lin1.weight.requires_grad上下文中完成的。

model.eval()

如果您的目标不是微调,而是将模型设置为推理模式,则最方便的方法是使用no_grad上下文管理器。在这种情况下,您还必须将模型设置为评估模式,这可以通过在torch.no_grad上调用eval()来实现,例如:

nn.Module

此操作将图层的属性model = torchvision.models.vgg16(pretrained=True) model.eval() 设置为self.training,实际上,这将更改FalseDropout之类的操作的行为,这些行为在训练时必须表现出不同和测试时间。

答案 1 :(得分:5)

这是路;

linear = nn.Linear(1,1)

for param in linear.parameters():
    param.requires_grad = False

with torch.no_grad():
    linear.eval()
    print(linear.weight.requires_grad)

输出:错误

答案 2 :(得分:3)

要完成@Salih_Karagoz的答案,您还具有torch.set_grad_enabled()上下文(更多文档here),可用于轻松地在训练/评估模式之间切换:

linear = nn.Linear(1,1)

is_train = False
with torch.set_grad_enabled(is_train):
    linear.eval()
    print(linear.weight.requires_grad)

答案 3 :(得分:0)

tutorial可能会有所帮助。

简而言之,我认为解决这个问题的好方法可能是:

linear = nn.Linear(1,1)

for param in linear.parameters():
    param.requires_grad = False

linear.eval()
print(linear.weight.requires_grad)

答案 4 :(得分:0)

好。诀窍是检查您在定义线性定律时,默认情况下参数是否具有requires_grad=True,因为我们想学习,对吧?

l = nn.Linear(1, 1)
p = l.parameters()
for _ in p:
    print (_)

# Parameter containing:
# tensor([[-0.3258]], requires_grad=True)
# Parameter containing:
# tensor([0.6040], requires_grad=True)    

另一个构造

with torch.no_grad():

意味着您无法在此处学习。

因此,即使您处于torch.no_grad()禁止学习的地方,您的代码也仅显示您具有学习能力。

with torch.no_grad():
    linear = nn.Linear(1, 1)
    linear.eval()
    print(linear.weight.requires_grad) #true

如果您确实打算为权重参数关闭requires_grad,也可以使用以下方法进行操作:

linear.weight.requires_grad_(False)

linear.weight.requires_grad = False

所以您的代码可能会变成这样:

with torch.no_grad():
    linear = nn.Linear(1, 1)
    linear.weight.requires_grad_(False)
    linear.eval()
    print(linear.weight.requires_grad)

如果您打算为模块中的所有参数切换到require_grad:

l = nn.Linear(1, 1)
for _ in l.parameters():
    _.requires_grad_(False)
    print(_)