TensorFlow:Saver有5个型号限制

时间:2016-08-08 19:44:05

标签: python machine-learning tensorflow

我想为我的实验保存多个模型,但我注意到tf.train.Saver()构造函数无法保存超过5个模型。这是一个简单的代码:

import tensorflow as tf 

x = tf.Variable(tf.zeros([1]))
saver = tf.train.Saver()
sess = tf.Session()

for i in range(10):
  sess.run(tf.initialize_all_variables())
  saver.save( sess, '/home/eneskocabey/Desktop/model' + str(i) )

当我运行此代码时,我在桌面上只看到了5个型号。为什么是这样?如何使用相同的tf.train.Saver()构造函数保存5个以上的模型?

2 个答案:

答案 0 :(得分:23)

tf.train.Saver() constructor采用名为max_to_keep的可选参数,默认保留模型的5个最新检查点。要保存更多模型,只需为该参数指定一个值:

import tensorflow as tf 

x = tf.Variable(tf.zeros([1]))
saver = tf.train.Saver(max_to_keep=10)
sess = tf.Session()

for i in range(10):
  sess.run(tf.initialize_all_variables())
  saver.save(sess, '/home/eneskocabey/Desktop/model' + str(i))

要保留所有检查点,请将参数max_to_keep=None传递给保护程序构造函数。

答案 1 :(得分:2)

  1. 如果您使用自己的tf.Session()进行培训:

为了保留中间检查点而不是最后5个,您需要在tf.train.Saver()中更改2个参数:

  • max_to_keep-表示要保留的最近检查点文件的最大数量。创建新文件时,将删除旧文件。如果为None或0,则不会从文件系统中删除任何检查点,但是只有最后一个检查点保留在检查点文件中。默认值为5(即,保留5个最新的检查点文件。)
  • keep_checkpoint_every_n_hours-除了保留最新的max_to_keep检查点文件之外,您可能还希望每N个小时的训练保留一个检查点文件。如果您以后要分析模型在长时间的培训中的进展情况,这将很有用。例如,传递keep_checkpoint_every_n_hours = 2可以确保每2个小时的训练保留一个检查点文件。默认值为10,000小时有效地禁用了该功能。

因此,如果执行以下操作,则将每2小时存储一个检查点,并且如果保存的检查点总数达到10,则最早的检查点将被删除,而新的检查点将被替换:

saver = tf.train.Saver(max_to_keep=10, keep_checkpoint_every_n_hours=2)
  1. 如果使用tf.estimator.Estimator(),则检查点的保存由Estimator本身完成。因此,您需要为它传递一个带有以下某些参数的tf.estimator.RunConfig()

    • keep_checkpoint_max-要保留的最近检查点文件的最大数量。创建新文件时,将删除旧文件。如果为None或0,则保留所有检查点文件。默认值为5(即,保留5个最新的检查点文件。)
    • save_checkpoints_steps-每隔这么多个步骤保存检查点。不能用save_checkpoints_secs指定。
    • save_checkpoints_secs-每隔几秒钟保存检查点。不能用save_checkpoints_steps指定。如果未在构造函数中同时设置save_checkpoints_stepssave_checkpoints_secs,则默认为600秒。如果save_checkpoints_stepssave_checkpoints_secs均为“无”,则禁用检查点。

因此,如果执行以下操作,则将每100次迭代存储一个检查点,并且如果保存的检查点总数达到10个,则将删除最早的检查点,并用一个新的检查点替换该检查点:

run_config = tf.estimator.RunConfig()
run_config = run_config.replace(keep_checkpoint_max=10, 
    save_checkpoints_steps=100)
classifier = tf.estimator.Estimator(
    model_fn=model_fn, model_dir=model_dir, config=run_config)