Using TensorFlow's Supervisor

Date: 2017-04-24 17:46:00

标签: python machine-learning tensorflow deep-learning

I am trying to run the textsum model from the TensorFlow open-source models repository. In seq2seq_attention.py, a tf.train.Supervisor is used to manage saving the model. The problem is that after the application starts, the Supervisor starts up by creating the checkpoint, graph files, and so on, but it does not save the model every 60 seconds as the save_model_secs=60 argument specifies; the next save only happens hours later. I tried removing the global_step variable, but the problem remained. So every time I stop training, I have to resume from the beginning and lose the running average loss (avg_loss). Can anyone tell me what the solution is? (A possible stopgap is sketched after the code below.)

The code in question is:

def _Train(model, data_batcher):
  """Runs model training."""
  with tf.device('/cpu:0'):
    model.build_graph()
    saver = tf.train.Saver()
    # Train dir is different from log_root to avoid summary directory
    # conflict with Supervisor.
    summary_writer = tf.summary.FileWriter(FLAGS.train_dir)
    sv = tf.train.Supervisor(logdir=FLAGS.log_root,
                             is_chief=True,
                             saver=saver,
                             summary_op=None,
                             save_summaries_secs=60,
                             save_model_secs=60,
                             global_step=model.global_step)
    sess = sv.prepare_or_wait_for_session(config=tf.ConfigProto(
        allow_soft_placement=True))
    running_avg_loss = 0
    step = 0
    while not sv.should_stop() and step < FLAGS.max_run_steps:
      (article_batch, abstract_batch, targets, article_lens, abstract_lens,
       loss_weights, _, _) = data_batcher.NextBatch()
      (_, summaries, loss, train_step) = model.run_train_step(
          sess, article_batch, abstract_batch, targets, article_lens,
          abstract_lens, loss_weights)
      summary_writer.add_summary(summaries, train_step)
      running_avg_loss = _RunningAvgLoss(
          running_avg_loss, loss, summary_writer, train_step)
      step += 1
      if step % 100 == 0:
        summary_writer.flush()
    sv.Stop()
    return running_avg_loss
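
As a stopgap while the Supervisor's background saver misbehaves, one could checkpoint explicitly from inside the training loop above. This is a minimal sketch, not part of the original code, assuming sv.saver is the Saver handed to the Supervisor and that checkpointing every 1000 steps is acceptable:

import os  # assumed: for building the checkpoint path under log_root

# Hypothetical stopgap (not in the original post): inside the training
# loop, force a checkpoint every 1000 steps, bypassing the Supervisor's
# background save thread entirely.
if step % 1000 == 0:
  sv.saver.save(sess, os.path.join(FLAGS.log_root, 'model.ckpt'),
                global_step=model.global_step)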

1 Answer:

Answer 0 (score: 0):

Have you tried specifying the duration between saves when instantiating the Saver? What I mean is (to save the model every 15 minutes):

saver = tf.train.Saver(keep_checkpoint_every_n_hours=0.25)
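
For context, here is a minimal sketch (mine, not the answerer's) of wiring such a Saver into the Supervisor setup from the question; FLAGS and model are assumed to come from the surrounding textsum code. Note that, per the TensorFlow documentation, keep_checkpoint_every_n_hours controls how long old checkpoints are retained from deletion, while the Supervisor's save_model_secs controls how often the background save thread actually writes one:

# Sketch only (assumes the same FLAGS and model as in the question).
# keep_checkpoint_every_n_hours=0.25 permanently retains one checkpoint
# per 15 minutes of training; max_to_keep bounds the rolling window of
# recent checkpoints; save_model_secs sets the background save interval.
saver = tf.train.Saver(max_to_keep=5,
                       keep_checkpoint_every_n_hours=0.25)
sv = tf.train.Supervisor(logdir=FLAGS.log_root,
                         is_chief=True,
                         saver=saver,
                         summary_op=None,
                         save_summaries_secs=60,
                         save_model_secs=60,
                         global_step=model.global_step)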