如何在使用tf.train.MonitoredTrainingSession

时间:2018-01-19 10:04:18

标签: python tensorflow machine-learning save

当我们在Saver.save中指定global_step时,它会将global_step存储为检查点后缀。

# save the checkpoint
saver = tf.train.Saver()
saver.save(session, checkpoints_path, global_step)

我们可以恢复检查点并获取检查点中存储的最后一个全局步骤,如下所示:

# restore the checkpoint and obtain the global step
saver.restore(session, ckpt.model_checkpoint_path)
...
_, gstep = session.run([optimizer, global_step], feed_dict=feed_dict_train)

如果我们使用tf.train.MonitoredTrainingSession,将全局步骤保存到检查点并获取gstep的等效方法是什么?

编辑1

根据Maxim的建议,我在global_step之前创建了tf.train.MonitoredTrainingSession变量,并添加了CheckpointSaverHook这样的内容:

global_step = tf.train.get_or_create_global_step()
save_checkpoint_hook = tf.train.CheckpointSaverHook(checkpoint_dir=checkpoints_abs_path,
                                                    save_steps=5,
                                                    checkpoint_basename=(checkpoints_prefix + ".ckpt"))

with tf.train.MonitoredTrainingSession(master=server.target,
                                       is_chief=is_chief,                     
                                       hooks=[sync_replicas_hook, save_checkpoint_hook],
                                       config=config) as session:

    _, gstep = session.run([optimizer, global_step], feed_dict=feed_dict_train)
    print("current global step=" + str(gstep))

我可以看到它生成的检查点文件与Saver.saver类似。但是,它无法从检查点检索全局步骤。请告知我该如何解决这个问题?

1 个答案:

答案 0 :(得分:3)

您可以通过tf.train.get_global_step()tf.train.get_or_create_global_step()功能获取当前的全局步骤。应该在训练开始前调用后者。

对于受监控的会话,将tf.train.CheckpointSaverHook添加到hooks# File upload settings file { folder = {$plugin.tx_powermail.settings.misc.uploadFolder} size = {$plugin.tx_powermail.settings.misc.uploadSize} extension = {$plugin.tx_powermail.settings.misc.uploadFileExtensions} randomizeFileName = 1 } 在内部使用定义的全局步长张量,在每N步后保存模型。