如何在张量流中保存迭代模型和最佳模型?

时间:2018-01-18 14:51:51

标签: python tensorflow machine-learning cross-validation

我试图保存模型的迭代检查点,以及在独立验证数据集上获得最佳分数的模型。然而,我的检查点覆盖了我最好的模型。实际上,我使用的是:

saver = tf.train.Saver()

with tf.Session() as sess:
    for epoch in range(20):
        # Train model [...]

        # and save a checkpoint 
        saver.save(sess, "iter", global_step=epoch)

        if best_validiation_acc < last_validation_acc:
            saver.save(sess, "best_model")

如何让我的最佳模型不被我的迭代保存覆盖?

1 个答案:

答案 0 :(得分:3)

原因是您为两者使用相同的tf.train.Saver,因此无论您如何命名,它都会记住最后max_to_keep=5个检查点文件。

最简单的解决方案是设置max_to_keep=None,这将强制保护程序保留所有检查点,而不是覆盖任何。但是,您可能更愿意至少覆盖迭代检查点。在这种情况下的解决方案是:

iter_saver = tf.train.Saver(max_to_keep=3)  # keep 3 last iterations
best_saver = tf.train.Saver(max_to_keep=5)  # keep 5 last best models

with tf.Session() as sess:
    for epoch in range(20):
        # Train model [...]

        # and save a checkpoint 
        iter_saver.save(sess, "iter/model", global_step=epoch)

        if best_validiation_acc < last_validation_acc:
            best_saver.save(sess, "best/model")

我还使用不同的目录,以便checkpoint文件不会发生冲突。