我正在使用tensorflow并且已经使用tf.saver()
方法训练了一些模型并在每个纪元后保存它们。我能够很好地保存和加载模型,我正在以通常的方式做到这一点。
with tf.Graph().as_default(), tf.Session() as session:
initialiser = tf.random_normal_initializer(config.mean, config.std)
with tf.variable_scope("model",reuse=None, initializer=initialiser):
m = a2p(session, config, training=True)
saver = tf.train.Saver()
ckpt = tf.train.get_checkpoint_state(model_dir)
if ckpt and tf.gfile.Exists(ckpt.model_checkpoint_path)
saver.restore(session, ckpt.model_checkpoint_path)
...
for i in range(epochs):
runepoch()
save_path = saver.save(session, '%s.ckpt'%i)
我的代码设置为保存每个时期的模型,应该相应地标记。但是,我注意到,在十五个训练时期之后,我只有最后五个时期的检查点文件(10,11,12,13,14)。文档没有说明这一点,所以我不知道为什么会发生这种情况。
保护程序是否只允许保留五个检查点或我做错了什么?
有没有办法确保保留所有检查点?
答案 0 :(得分:8)
您可以通过设置默认为5的max_to_keep
参数来选择create your Saver
object时要保存的检查点数。
saver = tf.train.Saver(max_to_keep=10000)
答案 1 :(得分:2)
设置max_to_keep=None
实际上使Saver保留所有检查点。
例如,
saver = tf.train.Saver(max_to_keep=None)