从检查点加载模型后,验证准确性下降

时间:2019-11-19 01:35:50

标签: python pytorch

我使用此代码在每个时期保存我的VAE模型:

torch.save({'epoch' : epoch,
            'encoder' : encoder.state_dict(),
            'decoder' : decoder.state_dict(),
            'property_predictor_model' : property_predictor_model.state_dict(),
            'train_ids' : train_ids,
            'valid_ids' : valid_ids,
            'validationQuality' : validationQuality
}, '{}/model_checkpoint{}.pt'.format(out_dir, epoch))

当我重新开始训练时,我使用以下代码初始化模型并加载以前的状态:

model_encode = VAE_encode(**encoder_parameter).to(device)
model_decode = VAE_decode(**decoder_parameter).to(device)
model_prop_predict = property_predictor_model(**prop_pred_parameter).to(device)

model_encode.load_state_dict(checkpoint['encoder'])
model_decode.load_state_dict(checkpoint['decoder'])
model_prop_predict.load_state_dict(checkpoint['property_predictor_model'])

保存在检查点中的Validation准确度约为60%,但是在加载模型后立即运行验证会返回<10%的准确度。训练就像我只是随机初始化模型一样继续进行。我加载状态前后的验证准确性不同,因此加载的是正确的东西,而不是正确的东西。任何想法为什么会这样。

有趣的笔记: 如果我使用调试器并在保存检查点后立即加载检查点,则可以保持验证的准确性,但是如果我取消脚本并加载相同的检查点,则准确性再次为<10%。

0 个答案:

没有答案