在Tensorflow中,是否有一种简单的方法可以在模型检查点发生时注册回调函数?

时间:2017-05-30 18:28:02

标签: tensorflow

我正在查看Tensorflow实现的一个名为Skip-thought Vector models的文本编码器模型的代码:https://github.com/tensorflow/models/tree/master/skip_thoughts

训练脚本中的代码包含以下内容:

saver = tf.train.Saver()

tf.contrib.slim.learning.train(
  train_op=train_tensor,
  logdir=FLAGS.train_dir,
  graph=g,
  global_step=model.global_step,
  number_of_steps=training_config.number_of_steps,
  save_summaries_secs=training_config.save_summaries_secs,
  saver=saver,
  save_interval_secs=training_config.save_model_secs)

显然,模型检查点每隔training_config.save_model_secs秒保存一次。

我想知道是否有办法注册某种类型的回调函数,每次都会在模型检查点发生后调用它。具体来说,我想将模型检查点移动/复制到其他一些网络位置。

1 个答案:

答案 0 :(得分:0)

CheckpointSaverListener(请参阅code)是一种可行的方法,但需要使用MonitoredTrainingSession而不是依赖苗条的api,因此您需要重新实现一些slim.train方法的逻辑。

# Class example from TensorFlow link above
class ExampleCheckpointSaverListerner(CheckpointSaverListener):
    def after_save(self, session, global_step_value):
        print('Done writing checkpoint.')
    ...

# Pseudo-code to illustrate how to use it
your_hooks = [ExampleCheckpointSaverListerner()]
step = 0
with tf.train.MonitoredTrainingSession(checkpoint_dir=FLAGS.train_dir,
                                       chief_only_hooks=your_hooks) as sess:
    # Your training loop
    while step < num_loop:
        _, step = sess.run([train_tensor, model.global_step], ...)