如何加载预训练的chainer npz参数文件但是修改了一些图层?

时间:2017-08-09 11:10:11

标签: chainer

我已经以chainer可加载的npz文件格式预先训练了一个VGG网络,但是在最后一层添加了一个新的FC层,我修改了最后一层输出class_number。 我已经修改了图层名称,以便将chainer可加载文件用于其他未更改的图层。 但我失败了。

Traceback (most recent call last):
  File "chainercv/trainer/train.py", line 194, in <module>
    main()
  File "chainercv/trainer/train.py", line 85, in main
    mean_file=args.mean)  # 可改为/home/machen/face_expr/result/snapshot_model.npz
  File "/home/machen/face_expr/chainercv/links/model/faster_rcnn/faster_rcnn_vgg.py", line 131, in __init__
    chainer.serializers.load_npz(pretrained_model, self)
  File "/usr/local/anaconda3/lib/python3.6/site-packages/chainer-3.0.0a1-py3.6.egg/chainer/serializers/npz.py", line 140, in load_npz
    d.load(obj)
  File "/usr/local/anaconda3/lib/python3.6/site-packages/chainer-3.0.0a1-py3.6.egg/chainer/serializer.py", line 82, in load
    obj.serialize(self)
  File "/usr/local/anaconda3/lib/python3.6/site-packages/chainer-3.0.0a1-py3.6.egg/chainer/link.py", line 794, in serialize
    d[name].serialize(serializer[name])
  File "/usr/local/anaconda3/lib/python3.6/site-packages/chainer-3.0.0a1-py3.6.egg/chainer/link.py", line 794, in serialize
    d[name].serialize(serializer[name])
  File "/usr/local/anaconda3/lib/python3.6/site-packages/chainer-3.0.0a1-py3.6.egg/chainer/link.py", line 550, in serialize
    data = serializer(name, param.data)
  File "/usr/local/anaconda3/lib/python3.6/site-packages/chainer-3.0.0a1-py3.6.egg/chainer/serializers/npz.py", line 106, in __call__
    dataset = self.npz[key]
  File "/usr/local/anaconda3/lib/python3.6/site-packages/numpy/lib/npyio.py", line 237, in __getitem__
    raise KeyError("%s is not a file in the archive" % key)
KeyError: 'head/score_mod/W is not a file in the archive'

1 个答案:

答案 0 :(得分:1)

我制作了自定义类来覆盖这样的一些层。 您可以使用init标志

控制预训练模型的加载时间
class MyRes(chainer.Chain):
def __init__(self, path=default_path, init=False):
    super(MyRes, self).__init__(
        c1 = L.Convolution2D(None, 64, 7, 2, 3),
        resnet = ResNet50Layers(pretrained_model=None),
    )   
    if not init:
        serializers.load_npz(path, self.resnet)
    self.resnet.conv1 = self.c1

当你开始训练时,你可以简单地添加模型的路径

res = MyRes(path=pretrained_model_path)

当你加载训练有素的MyModel模型时,像这样设置init标志

res = MyRes(init=True)
serializers.load_npz(saved_myres, self.resnet)