我已经下载了pth.tar格式的预训练模型,一旦在其上调用torch.load()
,我就会看到它是orderdict
格式,带有相应的图层名称及其权重。然后,我尝试用字典理解简单地循环键和值以创建nn.ParameterDict()
,但是由于orderdict
的键中的命名约定,例如layer.1.weights
,我会遇到此错误
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-91-578f5b12a2c6> in <module>()
----> 1 nn.ParameterDict(checkpoint)
D:\Anaconda\lib\site-packages\torch\nn\modules\container.py in __init__(self, parameters)
425 super(ParameterDict, self).__init__()
426 if parameters is not None:
--> 427 self.update(parameters)
428
429 def __getitem__(self, key):
D:\Anaconda\lib\site-packages\torch\nn\modules\container.py in update(self, parameters)
492 if isinstance(parameters, OrderedDict):
493 for key, parameter in parameters.items():
--> 494 self[key] = parameter
495 else:
496 for key, parameter in sorted(parameters.items()):
D:\Anaconda\lib\site-packages\torch\nn\modules\container.py in __setitem__(self, key, parameter)
431
432 def __setitem__(self, key, parameter):
--> 433 self.register_parameter(key, parameter)
434
435 def __delitem__(self, key):
D:\Anaconda\lib\site-packages\torch\nn\modules\module.py in register_parameter(self, name, param)
136 "Got {}".format(torch.typename(name)))
137 elif '.' in name:
--> 138 raise KeyError("parameter name can't contain \".\"")
139 elif name == '':
140 raise KeyError("parameter name can't be empty string \"\"")
KeyError: 'parameter name can\'t contain "."'
因此,最后要进行此运行,我应该在layer.1.weights
转换期间将layer_1_weights
的图层重命名为nn.ParameterDict
吗?有关系吗?我还搜索了load_state_dict
,从我的基本理解出发,您需要预先定义模型类,然后再将权重加载到其中,在这种情况下,我没有模型类,我正在尝试使用此orderdict
文件中的信息来构建模型类。那么解决这个问题的正确方法是什么?