I'm looking at the TensorFlow implementation of the Skip-Thought Vectors text-encoder model: https://github.com/tensorflow/models/tree/master/skip_thoughts.
The training script contains the following code:
saver = tf.train.Saver()
tf.contrib.slim.learning.train(
    train_op=train_tensor,
    logdir=FLAGS.train_dir,
    graph=g,
    global_step=model.global_step,
    number_of_steps=training_config.number_of_steps,
    save_summaries_secs=training_config.save_summaries_secs,
    saver=saver,
    save_interval_secs=training_config.save_model_secs)
So the model checkpoint is saved every training_config.save_model_secs seconds.
I'd like to know whether there is a way to register some kind of callback function that gets invoked every time a checkpoint is written. Specifically, I want to move/copy each checkpoint to some other network location.
Answer 0 (score: 0)
CheckpointSaverListener (see the code) is one way to do this, but it requires using tf.train.MonitoredTrainingSession instead of the slim API, so you would need to reimplement some of the logic of slim.learning.train.
# Class example from the TensorFlow link above
class ExampleCheckpointSaverListener(tf.train.CheckpointSaverListener):
    def after_save(self, session, global_step_value):
        print('Done writing checkpoint.')
        # move/copy the checkpoint files here

# Pseudo-code to illustrate how to use it. Note that a listener is not
# itself a hook: it must be attached to a CheckpointSaverHook, and the
# session's default saver hook is disabled so only this one runs.
listener = ExampleCheckpointSaverListener()
saver_hook = tf.train.CheckpointSaverHook(
    checkpoint_dir=FLAGS.train_dir,
    save_secs=training_config.save_model_secs,
    listeners=[listener])
step = 0
with tf.train.MonitoredTrainingSession(
        checkpoint_dir=FLAGS.train_dir,
        save_checkpoint_secs=None,
        chief_only_hooks=[saver_hook]) as sess:
    # Your training loop
    while step < training_config.number_of_steps:
        _, step = sess.run([train_tensor, model.global_step])
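As for the move/copy itself, that logic would live inside after_save. One thing to keep in mind is that a TensorFlow checkpoint is not a single file: the prefix (e.g. model.ckpt-100) expands to several files (.index, .meta, .data-*), so the copy has to glob on the prefix. Here is a minimal sketch of such a helper in plain Python (independent of TF; the function name and destination layout are illustrative, not from the original answer):

```python
import glob
import os
import shutil

def copy_checkpoint(checkpoint_prefix, dest_dir):
    """Copy every file belonging to one checkpoint (prefix.index,
    prefix.meta, prefix.data-*, ...) into dest_dir.

    Returns the list of destination paths that were written."""
    os.makedirs(dest_dir, exist_ok=True)
    copied = []
    for path in glob.glob(checkpoint_prefix + '.*'):
        target = os.path.join(dest_dir, os.path.basename(path))
        shutil.copy2(path, target)  # copy2 preserves mtime
        copied.append(target)
    return copied
```

Inside after_save you would reconstruct the prefix from the step value, e.g. copy_checkpoint(os.path.join(FLAGS.train_dir, 'model.ckpt-%d' % global_step_value), your_network_dir). If the "network location" is a mounted filesystem this is all you need; for remote storage you would swap shutil.copy2 for the appropriate client call.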