我使用此代码在每个时期保存我的VAE模型:
torch.save({'epoch' : epoch,
'encoder' : encoder.state_dict(),
'decoder' : decoder.state_dict(),
'property_predictor_model' : property_predictor_model.state_dict(),
'train_ids' : train_ids,
'valid_ids' : valid_ids,
'validationQuality' : validationQuality
}, '{}/model_checkpoint{}.pt'.format(out_dir, epoch))
当我重新开始训练时,我使用以下代码初始化模型并加载以前的状态:
model_encode = VAE_encode(**encoder_parameter).to(device)
model_decode = VAE_decode(**decoder_parameter).to(device)
model_prop_predict = property_predictor_model(**prop_pred_parameter).to(device)
model_encode.load_state_dict(checkpoint['encoder'])
model_decode.load_state_dict(checkpoint['decoder'])
model_prop_predict.load_state_dict(checkpoint['property_predictor_model'])
保存在检查点中的Validation准确度约为60%,但是在加载模型后立即运行验证会返回<10%的准确度。训练就像我只是随机初始化模型一样继续进行。我加载状态前后的验证准确性不同,因此加载的是正确的东西,而不是正确的东西。任何想法为什么会这样。
有趣的笔记: 如果我使用调试器并在保存检查点后立即加载检查点,则可以保持验证的准确性,但是如果我取消脚本并加载相同的检查点,则准确性再次为<10%。