PyTorch Pretrained VGG19 KeyError

Date: 2017-06-19 12:18:47

Tags: python machine-learning deep-learning convolution pytorch

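I am fine-tuning the first 10 layers of a VGG19 network to extract features from images, but I get the following error and cannot find a way to fix it: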
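If the goal is only a feature extractor built from VGG19's first 10 conv layers, one alternative (not what the training code does) is to take those layers directly from `vgg19.features` instead of copying weights into a separate model. Below is a minimal sketch with a hypothetical helper `first_n_conv_layers`. It is demonstrated on a small, hypothetical VGG-style stack so that it runs without downloading pretrained weights.

```python
import torch
import torch.nn as nn


def first_n_conv_layers(features, n):
    """Return an nn.Sequential with every layer of `features` up to and including
    its n-th Conv2d, plus the ReLU that directly follows it (if any).
    The modules are shared with `features`, not copied."""
    layers = list(features.children())
    conv_idx = [i for i, m in enumerate(layers) if isinstance(m, nn.Conv2d)]
    if len(conv_idx) < n:
        raise ValueError("only %d conv layers available, asked for %d" % (len(conv_idx), n))
    end = conv_idx[n - 1] + 1
    # Keep the activation that belongs to the n-th conv.
    if end < len(layers) and isinstance(layers[end], nn.ReLU):
        end += 1
    return nn.Sequential(*layers[:end])


# Hypothetical toy stand-in for vgg19.features (same conv/ReLU/pool pattern).
toy = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(inplace=True),
    nn.Conv2d(8, 8, 3, padding=1), nn.ReLU(inplace=True),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(8, 16, 3, padding=1), nn.ReLU(inplace=True),
)
extractor = first_n_conv_layers(toy, 2)  # conv, relu, conv, relu
out = extractor(torch.randn(1, 3, 32, 32))
print(out.shape)
```

Applied to `models.vgg19(pretrained=True).features` with `n=10`, the cut falls after `features[22]`, the ReLU that follows the 10th conv at `features[21]`. The first 20 entries of `vgg19.state_dict()` are the weights and biases of exactly those 10 conv layers. Because the returned modules are the originals, they keep the pretrained weights without any `load_state_dict` call.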

Traceback (most recent call last):
  File "TODO-train_from_scratch.py", line 390, in <module>
    main()
  File "TODO-train_from_scratch.py", line 199, in main
    model.load_state_dict(weights_load)
  File "/usr/local/lib/python2.7/dist-packages/torch/nn/modules/module.py", line 339, in load_state_dict
    raise KeyError('missing keys in state_dict: "{}"'.format(missing)) 
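The traceback is consistent with `load_state_dict` matching keys strictly: every key in the model's own `state_dict()` must appear in the dictionary passed in. A dict holding only 20 tensors therefore fails when the model has more entries than that. Below is a minimal reproduction that uses a hypothetical two-layer `nn.Sequential` in place of `get_model()`.

```python
import torch.nn as nn

# Hypothetical stand-in for get_model(): two conv layers, i.e. four tensors
# ("0.weight", "0.bias", "2.weight", "2.bias").
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))

# A partial state dict holding only the first layer's tensors.
partial = {k: v for k, v in model.state_dict().items() if k.startswith("0.")}

try:
    model.load_state_dict(partial)  # strict key matching by default
    raised = False
except (KeyError, RuntimeError) as err:
    # PyTorch 0.1.x raised KeyError (as in the traceback); later releases raise RuntimeError.
    raised = True
    print(type(err).__name__, err)
```

The exception type depends on the PyTorch version, so the sketch catches both.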

The corresponding snippet from the training code is:

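import torch
import torchvision.models as models

# create model
vgg19 = models.vgg19(pretrained=True)
vgg19_state_dict = vgg19.state_dict()  # ordered: features.0.weight, features.0.bias, ...
vgg19_keys = list(vgg19_state_dict.keys())

model = get_model()  # user-defined network (definition not shown)
model_keys = list(model.state_dict().keys())

# Copy the first 20 tensors (weight and bias of VGG19's first 10 conv layers), matched by position.
weights_load = {}
for i in range(20):
    weights_load[model_keys[i]] = vgg19_state_dict[vgg19_keys[i]]

model.load_state_dict(weights_load)  # the KeyError in the traceback is raised here
model = torch.nn.DataParallel(model).cuda()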
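Below is a minimal sketch of one way around the error. It assumes that `get_model()` lists its first 20 tensors in the same order, and with the same shapes, as VGG19's first 10 conv layers. The idea is to start from the model's own `state_dict()`, overwrite only the first `n` entries, and load the now-complete dict. The helper name `load_first_n_tensors` and the toy models are hypothetical. Positional matching breaks if either network interleaves extra entries (for example BatchNorm buffers), so the sketch checks shapes before loading.

```python
import torch
import torch.nn as nn


def load_first_n_tensors(model, src_state, n):
    """Copy the first n tensors of src_state into model, matched by position.

    Starts from the model's own state dict so every key it expects is present;
    only the first n entries are overwritten.
    """
    target_state = model.state_dict()
    target_keys = list(target_state.keys())[:n]
    src_items = list(src_state.items())[:n]
    for tgt_key, (src_key, tensor) in zip(target_keys, src_items):
        if target_state[tgt_key].size() != tensor.size():
            raise ValueError("shape mismatch: %s %s vs %s %s" % (
                tgt_key, tuple(target_state[tgt_key].size()),
                src_key, tuple(tensor.size())))
        target_state[tgt_key] = tensor
    model.load_state_dict(target_state)  # complete dict, so strict loading succeeds
    return model


# Hypothetical toy stand-ins: src plays the role of VGG19, dst of get_model().
src = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 8, 3))
dst = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 4, 3))
load_first_n_tensors(dst, src.state_dict(), 2)  # copies only layer 0's weight and bias
print(torch.equal(dst[0].weight, src[0].weight))
```

In the training code, this would replace the loop and the `load_state_dict` call, e.g. `load_first_n_tensors(model, vgg19_state_dict, 20)` before wrapping the model in `DataParallel`. Later PyTorch releases also accept `model.load_state_dict(weights_load, strict=False)`, which skips missing keys instead of raising.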
