我目前正在使用fast.ai来训练图像分类器模型。
data = ImageDataBunch.single_from_classes(path, classes, ds_tfms=get_transforms(), size=224).normalize(imagenet_stats)
learner = cnn_learner(data, models.resnet34)
learner.model.load_state_dict(
torch.load('stage-2.pth', map_location="cpu")
)
结果为:
torch.load('stage-2.pth',map_location =“ cpu”)文件 “ /usr/local/lib/python3.6/site-packages/torch/nn/modules/module.py”, 在load_state_dict中的第769行 self。类。名称,“ \ n \ t”。join(error_msgs)))RuntimeError:加载state_dict序列时出错:
...
state_dict中的意外密钥:“模型”,“选择”。
我在SO中环顾四周,并尝试使用以下解决方案:
# original saved file with DataParallel
state_dict = torch.load('stage-2.pth', map_location="cpu")
# create new OrderedDict that does not contain `module.`
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`
new_state_dict[name] = v
# load params
learner.model.load_state_dict(new_state_dict)
结果为:
RuntimeError:为顺序加载state_dict时发生错误:
state_dict中出现意外的键:“”。
我正在使用Google Colab训练我的模型,然后将训练后的模型移植到docker中,并尝试在本地服务器中进行托管。
可能是什么问题?可能是pytorch的不同版本导致模型不匹配吗?
在我的docker配置中:
# Install pytorch and fastai
RUN pip install torch_nightly -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
RUN pip install fastai
我的Colab使用以下工具时:
!curl -s https://course.fast.ai/setup/colab | bash
答案 0 :(得分:2)
我的强烈猜测是stage-2.pth
包含两个顶级项:模型本身(权重)和用于训练模型的优化器的最终状态。要仅加载模型,只需要前者即可。假设事情是以惯用的PyTorch方式完成的,我会尝试
learner.model.load_state_dict(
torch.load('stage-2.pth', map_location="cpu")['model']
)
更新:在应用了我的第一轮建议之后,很明显,您正在加载的保存点创建与正在加载的模型不同(可能配置不同?)。如您在pastebin中所见,保存点包含模型中不存在的一些额外图层的权重,例如bn3
,downsample
等。
“ 0.4.0.bn3.running_var”,“ 0.4.0.bn3.num_batches_tracked”,“ 0.4.0.downsample.0.weight”
同时还有其他一些键名匹配,但是张量的形状不同。
0.5.0.downsample.0.weight的大小不匹配:从检查点复制形状为torch.Size([512,256,1,1])的参数,当前模型中的形状为torch.Size([128 ,64,1,1])。
我看到一个模式,您一直尝试加载形状为[2^(x+1), 2^x, 1, 1]
的参数来代替[2^(x), 2^(x-1), 1, 1]
。也许您正在尝试加载不同深度的模型(例如,加载vgg-11的vgg-16权重?)。无论哪种方式,您都需要弄清楚用于创建保存点的确切体系结构,然后在加载保存点之前重新创建它。
PS。如果不确定,保存点将包含模型权重,形状和(自动生成的)名称。它们不不包含体系结构本身的完整规范-您需要确保自己正在调用model.load_state_dict
,其中model
与所使用的体系结构完全相同创建保存点。否则,您的体重名称可能会不匹配。