我在Pytorch有自己的网络。它首先针对二进制分类器(2个类)进行了训练。历经10k次后,我得到的训练权重为10000_model.pth
。现在,我想使用同一网络将模型用于4类分类器问题。因此,我想将二元分类器中所有训练过的权重转移到4类问题,而没有会随机初始化的lass层。我该怎么办?这是我的模特
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5, 1)
self.conv2 = nn.Conv2d(20, 50, 5, 1)
self.conv_classify= nn.Conv2d(50, 2, 1, 1, bias=True) # number of class
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv_classify(x))
return x
这就是我所做的
model = Net ()
checkpoint_dict = torch.load('10000_model.pth')
pretrained_dict = checkpoint_dict['state_dict']
model_dict = model.state_dict()
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)
现在,我必须按名称手动删除pretrained_dict。
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
pretrained_dict.pop('conv_classify.weight', None)
pretrained_dict.pop('conv_classify.bias', None)
这意味着pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
不执行任何操作。
怎么了?我正在使用pytorch 1.0。谢谢
答案 0 :(得分:2)
两个网络都具有相同的层,因此state_dict
中的密钥也相同,
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
什么都不做。两者之间的区别在于张量(形状)而不是名称。换句话说,您可以通过[v.shape for v in model.state_dict().values()]
而不是model.state_dict().keys()
来区分两者。您的“替代方法”方法是正确的。如果您想减少一些手动操作,我会使用
merged_dict = {}
for key in model_dict.keys():
if 'conv_classify' in key: # or perhaps a more complex criterion
merged_dict[key] = model_dict[key]
else:
merged_dict[key] = pretrained_dict[key]