我正在用渴望模式下的带有tensorflow的a3c的基于过程的实现。每次梯度更新后,我的通用模型都会将其参数作为检查点写入文件夹。然后,工作人员通过从该文件夹加载最后的检查点来更新其参数。但是,有一个问题。
通常,当工作进程正在从文件夹中读取最后一个可用的检查点时,主网络会将新的检查点写入文件夹中,有时会擦除工作进程正在读取的检查点。一个简单的解决方案是提高要保留的最大检查点数。但是,tfe.Checkpoint和tfe.Saver没有参数来选择要保留的最大值。
有没有办法做到这一点?
答案 0 :(得分:0)
对于tf.train.Saver,您可以指定max_to_keep
:
tf.train.Saver(max_to_keep = 10)
{p>和max_to_keep
似乎同时出现在fte.Saver和tf.training.Saver中。
我还没有尝试过。
答案 1 :(得分:0)
看来,建议删除检查点的方法是使用CheckpointManager。
import tensorflow as tf
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
manager = tf.contrib.checkpoint.CheckpointManager(
checkpoint, directory="/tmp/model", max_to_keep=5)
status = checkpoint.restore(manager.latest_checkpoint)
while True:
# train
manager.save()