我如何仅使用pytorch更新网络中的某些特定张量?

时间:2018-08-29 04:13:53

标签: image-processing machine-learning deep-learning conv-neural-network pytorch

例如,我只想在前10个时期中更新Resnet中的所有cnn权重,然后冻结其他权重。
从第11个时代开始,我想更改以更新整个模型。
我如何实现目标?

2 个答案:

答案 0 :(得分:5)

您可以为每个参数组设置学习率(以及其他一些元参数)。您只需要根据需要对参数进行分组即可。
例如,为转换层设置不同的学习率:

import torch
import itertools
from torch import nn

conv_params = itertools.chain.from_iterable([m.parameters() for m in model.children()
                                             if isinstance(m, nn.Conv2d)])
other_params = itertools.chain.from_iterable([m.parameters() for m in model.children()
                                              if not isinstance(m, nn.Conv2d)]) 
optimizer = torch.optim.SGD([{'params': other_params},
                             {'params': conv_params, 'lr': 0}],  # set init lr to 0
                            lr=lr_for_model)

您以后可以访问优化器param_groups并修改学习率。

有关更多信息,请参见per-parameter options

答案 1 :(得分:0)

非常简单,因为PYTORCH可以即时重新创建计算图。

for p in resnet.parameters():
    p.requires_grad = False # this will freeze the module from training suppose that resnet is one of your module

如果您有多个模块,只需在其上循环即可。然后在10个时代之后,您只需致电

for p in network.parameters():
    p.requires_grad = True # suppose your whole network is the 'network' module