I am trying to run the textsum model from the TensorFlow open-source models repository. In seq2seq_attention.py they use a Supervisor to manage saving the model. The problem is that after launching the application, the Supervisor starts up by creating the checkpoint, the graph, and so on, but it does not save the model after 60 seconds as the given parameter specifies; the next save only happens hours later. I tried removing the global_step variable, but the problem remains. Every time I stop training, I have to resume from the very beginning (avg_loss). Can anyone tell me what the solution is?
The code in question is:
def _Train(model, data_batcher):
  """Runs model training."""
  with tf.device('/cpu:0'):
    model.build_graph()
    saver = tf.train.Saver()
    # Train dir is different from log_root to avoid summary directory
    # conflict with Supervisor.
    summary_writer = tf.summary.FileWriter(FLAGS.train_dir)
    sv = tf.train.Supervisor(logdir=FLAGS.log_root,
                             is_chief=True,
                             saver=saver,
                             summary_op=None,
                             save_summaries_secs=60,
                             save_model_secs=60,
                             global_step=model.global_step)
    sess = sv.prepare_or_wait_for_session(config=tf.ConfigProto(
        allow_soft_placement=True))
    running_avg_loss = 0
    step = 0
    while not sv.should_stop() and step < FLAGS.max_run_steps:
      (article_batch, abstract_batch, targets, article_lens, abstract_lens,
       loss_weights, _, _) = data_batcher.NextBatch()
      (_, summaries, loss, train_step) = model.run_train_step(
          sess, article_batch, abstract_batch, targets, article_lens,
          abstract_lens, loss_weights)
      summary_writer.add_summary(summaries, train_step)
      running_avg_loss = _RunningAvgLoss(
          running_avg_loss, loss, summary_writer, train_step)
      step += 1
      if step % 100 == 0:
        summary_writer.flush()
    sv.Stop()
    return running_avg_loss
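For context, save_model_secs=60 is supposed to make the Supervisor's background thread write a checkpoint roughly once a minute, independently of the training loop. If that timer does not fire as expected, one workaround is to save explicitly from inside the loop using the same elapsed-time pattern. The timer logic itself is plain Python; in this minimal sketch, train_step_fn and save_fn are hypothetical callbacks standing in for model.run_train_step and saver.save(sess, path, global_step=...):

```python
import time

def train_with_periodic_save(num_steps, save_secs, train_step_fn, save_fn):
    """Run train_step_fn num_steps times, calling save_fn whenever at
    least save_secs seconds have elapsed since the last save."""
    last_save = time.time()
    for step in range(num_steps):
        train_step_fn(step)
        now = time.time()
        if now - last_save >= save_secs:
            save_fn(step)
            last_save = now

# Demonstration with stub callbacks; a zero interval saves on every step.
saves = []
train_with_periodic_save(
    num_steps=5,
    save_secs=0.0,
    train_step_fn=lambda s: None,
    save_fn=saves.append)
print(len(saves))  # 5
```

Putting the saver.save call in the loop this way guarantees a checkpoint at least every save_secs of wall-clock time, regardless of what the Supervisor's own thread does.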
Answer 0 (score: 0)
Have you tried specifying the duration between saves when instantiating the Saver? I mean (saving the model every 15 minutes):
saver = tf.train.Saver(keep_checkpoint_every_n_hours=0.25)
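A caveat worth noting: in the TF 1.x API, keep_checkpoint_every_n_hours controls retention (which old checkpoints are kept permanently, in addition to the max_to_keep most recent), not how often new checkpoints are written; the write interval is governed by the Supervisor's save_model_secs argument. A sketch combining the two, assuming the TF 1.x API (the logdir path is hypothetical):

```python
import tensorflow as tf  # TF 1.x API

# Saver settings control retention: keep the 5 most recent checkpoints,
# plus one permanent checkpoint every 15 minutes of training.
saver = tf.train.Saver(max_to_keep=5, keep_checkpoint_every_n_hours=0.25)

# Supervisor setting controls frequency: write a checkpoint every 60 s.
sv = tf.train.Supervisor(logdir='/tmp/log_root',
                         saver=saver,
                         save_model_secs=60)
```

So if the goal is more frequent saves, save_model_secs is the knob to check, while keep_checkpoint_every_n_hours only decides how long the resulting files survive.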