将pytorch模型从0.4.1加载到0.4.0?

时间:2018-12-07 23:44:21

标签: python deep-learning pytorch

我使用pytorch 0.4.1 (GPU)训练了DENSENET161模型,在测试环境中,我必须将其加载到pytorch版本 0.4.0 (CPU)中。我已经在使用model.cpu() 但是当我加载静态字典model.load_state_dict(checkpoint['state_dict'])

我遇到以下错误:

  

RuntimeError:加载DenseNet的state_dict时出错:意外   state_dict中的键:“ features.norm0.num_batches_tracked”,   “ features.denseblock1.denselayer1.norm1.num_batches_tracked”,   “ features.denseblock1.denselayer1.norm2.num_batches_tracked”,   “ features.denseblock1.denselayer2.norm1.num_batches_tracked”,...

1 个答案:

答案 0 :(得分:1)

这似乎是由于PyTorch 0.4.1和0.4之间在规范化层的实现上的差异而引起的-前者跟踪一些称为num_batches_tracked的状态变量,而pytorch 0.4则不会。假设只有意外的键而没有丢失的键(由于您已经裁剪了错误消息,所以我不能肯定地说出这些键),您可以删除多余的键,并希望模型能够加载。因此尝试

model_dict = checkpoint['state_dict']
filtered = {
    k: v for k, v in model_dict.items() if 'num_batches_tracked' not in k
}
model.load_state_dict(filtered)

请注意,除了您在此处看到的内容之外,标准化的内部可能有所更改,因此,即使此修复程序抑制了该异常,该模型仍可能会默默地行为异常。