在pytorch中,如何才能将预训练的权重转换为orderdict格式并将其重新转换为模型?

时间:2018-11-29 20:34:08

标签: python machine-learning pytorch transfer-learning

我已经下载了pth.tar格式的预训练模型,一旦在其上调用torch.load(),我就会看到它是orderdict格式,带有相应的图层名称及其权重。然后,我尝试用字典理解简单地循环键和值以创建nn.ParameterDict(),但是由于orderdict的键中的命名约定,例如layer.1.weights,我会遇到此错误

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-91-578f5b12a2c6> in <module>()
----> 1 nn.ParameterDict(checkpoint)

D:\Anaconda\lib\site-packages\torch\nn\modules\container.py in __init__(self, parameters)
    425         super(ParameterDict, self).__init__()
    426         if parameters is not None:
--> 427             self.update(parameters)
    428 
    429     def __getitem__(self, key):

D:\Anaconda\lib\site-packages\torch\nn\modules\container.py in update(self, parameters)
    492             if isinstance(parameters, OrderedDict):
    493                 for key, parameter in parameters.items():
--> 494                     self[key] = parameter
    495             else:
    496                 for key, parameter in sorted(parameters.items()):

D:\Anaconda\lib\site-packages\torch\nn\modules\container.py in __setitem__(self, key, parameter)
    431 
    432     def __setitem__(self, key, parameter):
--> 433         self.register_parameter(key, parameter)
    434 
    435     def __delitem__(self, key):

D:\Anaconda\lib\site-packages\torch\nn\modules\module.py in register_parameter(self, name, param)
    136                             "Got {}".format(torch.typename(name)))
    137         elif '.' in name:
--> 138             raise KeyError("parameter name can't contain \".\"")
    139         elif name == '':
    140             raise KeyError("parameter name can't be empty string \"\"")

KeyError: 'parameter name can\'t contain "."'

因此,最后要进行此运行,我应该在layer.1.weights转换期间将layer_1_weights的图层重命名为nn.ParameterDict吗?有关系吗?我还搜索了load_state_dict,从我的基本理解出发,您需要预先定义模型类,然后再将权重加载到其中,在这种情况下,我没有模型类,我正在尝试使用此orderdict文件中的信息来构建模型类。那么解决这个问题的正确方法是什么?

0 个答案:

没有答案