我使用pytorch 0.4.1 (GPU)训练了DENSENET161模型,在测试环境中,我必须将其加载到pytorch版本 0.4.0 (CPU)中。我已经在使用model.cpu()
但是当我加载静态字典model.load_state_dict(checkpoint['state_dict'])
我遇到以下错误:
RuntimeError:加载DenseNet的state_dict时出错:意外 state_dict中的键:“ features.norm0.num_batches_tracked”, “ features.denseblock1.denselayer1.norm1.num_batches_tracked”, “ features.denseblock1.denselayer1.norm2.num_batches_tracked”, “ features.denseblock1.denselayer2.norm1.num_batches_tracked”,...
答案 0 :(得分:1)
这似乎是由于PyTorch 0.4.1和0.4之间在规范化层的实现上的差异而引起的-前者跟踪一些称为num_batches_tracked
的状态变量,而pytorch 0.4则不会。假设只有意外的键而没有丢失的键(由于您已经裁剪了错误消息,所以我不能肯定地说出这些键),您可以删除多余的键,并希望模型能够加载。因此尝试
model_dict = checkpoint['state_dict']
filtered = {
k: v for k, v in model_dict.items() if 'num_batches_tracked' not in k
}
model.load_state_dict(filtered)
请注意,除了您在此处看到的内容之外,标准化的内部可能有所更改,因此,即使此修复程序抑制了该异常,该模型仍可能会默默地行为异常。