state_dict中的意外密钥:“模型”,“选择”

时间:2019-03-07 15:12:04

标签: python deep-learning pytorch fast-ai

我目前正在使用fast.ai来训练图像分类器模型。

data = ImageDataBunch.single_from_classes(path, classes, ds_tfms=get_transforms(), size=224).normalize(imagenet_stats)
learner = cnn_learner(data, models.resnet34)

learner.model.load_state_dict(
    torch.load('stage-2.pth', map_location="cpu")
)

结果为:

  

torch.load('stage-2.pth',map_location =“ cpu”)文件   “ /usr/local/lib/python3.6/site-packages/torch/nn/modules/module.py”,   在load_state_dict中的第769行       self。名称,“ \ n \ t”。join(error_msgs)))RuntimeError:加载state_dict序列时出错:

     

...

     

state_dict中的意外密钥:“模型”,“选择”。

我在SO中环顾四周,并尝试使用以下解决方案:

# original saved file with DataParallel
state_dict = torch.load('stage-2.pth', map_location="cpu")
# create new OrderedDict that does not contain `module.`
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove `module.`
    new_state_dict[name] = v
# load params
learner.model.load_state_dict(new_state_dict)

结果为:

  

RuntimeError:为顺序加载state_dict时发生错误:

     

state_dict中出现意外的键:“”。

我正在使用Google Colab训练我的模型,然后将训练后的模型移植到docker中,并尝试在本地服务器中进行托管。

可能是什么问题?可能是pytorch的不同版本导致模型不匹配吗?

在我的docker配置中:

# Install pytorch and fastai
RUN pip install torch_nightly -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
RUN pip install fastai

我的Colab使用以下工具时:

 !curl -s https://course.fast.ai/setup/colab | bash

1 个答案:

答案 0 :(得分:2)

我的强烈猜测是stage-2.pth包含两个顶级项:模型本身(权重)和用于训练模型的优化器的最终状态。要仅加载模型,只需要前者即可。假设事情是以惯用的PyTorch方式完成的,我会尝试

learner.model.load_state_dict(
    torch.load('stage-2.pth', map_location="cpu")['model']
)

更新:在应用了我的第一轮建议之后,很明显,您正在加载的保存点创建与正在加载的模型不同(可能配置不同?)。如您在pastebin中所见,保存点包含模型中不存在的一些额外图层的权重,例如bn3downsample等。

  

“ 0.4.0.bn3.running_var”,“ 0.4.0.bn3.num_batches_tracked”,“ 0.4.0.downsample.0.weight”

同时还有其他一些键名匹配,但是张量的形状不同。

  

0.5.0.downsample.0.weight的大小不匹配:从检查点复制形状为torch.Size([512,256,1,1])的参数,当前模型中的形状为torch.Size([128 ,64,1,1])。

我看到一个模式,您一直尝试加载形状为[2^(x+1), 2^x, 1, 1]的参数来代替[2^(x), 2^(x-1), 1, 1]。也许您正在尝试加载不同深度的模型(例如,加载vgg-11的vgg-16权重?)。无论哪种方式,您都需要弄清楚用于创建保存点的确切体系结构,然后在加载保存点之前重新创建它。

PS。如果不确定,保存点将包含模型权重,形状和(自动生成的)名称。它们不包含体系结构本身的完整规范-您需要确保自己正在调用model.load_state_dict,其中model与所使用的体系结构完全相同创建保存点。否则,您的体重名称可能会不匹配。