Tensorflow使用`tf.train.MonitoredTrainingSession恢复`tf.Session`保存检查点

时间:2017-10-20 01:02:13

标签: session tensorflow checkpoint

我有使用tf.train.MonitoredTrainingSession训练CNN的代码。

当我创建新的tf.train.MonitoredTrainingSession时,我可以将checkpoint目录作为输入参数传递给会话,它会自动恢复它可以找到的最新保存的checkpoint。我可以设置hooks来训练直到某一步。例如,如果checkpoint的步骤为150,000,并且我想训练到200,000,我会将last_step放到200,000

只要使用checkpoint保存了最新的tf.train.MonitoredTrainingSession,上述过程就能完美运行。但是,如果我尝试恢复使用普通checkpoint保存的tf.Session,那么所有地狱都会破裂。它无法在图表中找到一些关键字。

培训完成了:

with tf.train.MonitoredTrainingSession(
    checkpoint_dir=FLAGS.retrain_dir,
    hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_training_steps),
           tf.train.NanTensorHook(loss),
           _LoggerHook()],
    config=tf.ConfigProto(
        log_device_placement=FLAGS.log_device_placement)) as mon_sess:
  while not mon_sess.should_stop():
    mon_sess.run(train_op)

如果checkpoint_dir属性的文件夹没有检查点,则会从头开始。如果它从之前的培训课程中保存了checkpoint,则会恢复最新的checkpoint并继续培训。

现在,我正在恢复最新的checkpoint并修改一些变量并保存它们:

saver = tf.train.Saver(variables_to_restore)

ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)

with tf.Session() as sess:
  if ckpt and ckpt.model_checkpoint_path:
    # Restores from checkpoint
    saver.restore(sess, ckpt.model_checkpoint_path)
    print(ckpt.model_checkpoint_path)
    restored_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
    FLAGS.max_training_steps = int(restored_step) + FLAGS.max_training_steps
  else:
    print('No checkpoint file found')
    return

  prune_convs(sess)
  saver.save(sess, FLAGS.retrain_dir+"model.ckpt-"+restored_step)

正如您所见,就在saver.save...之前,我正在修剪网络中的所有卷积层。无需描述完成的方式和原因。关键是网络实际上已被修改。然后我将网络保存到checkpoint

现在,如果我在已保存的已修改网络上部署测试,则测试工作正常。但是,当我尝试在保存的tf.train.MonitoredTrainingSession上运行checkpoint时,它会说:

  

检查点中找不到密钥conv1 / weight_loss / avg

另外,我注意到checkpoint保存的tf.Session的大小是checkpoint保存的tf.train.MonitoredTrainingSession的一半

我知道我做错了,有什么建议如何使这项工作?

1 个答案:

答案 0 :(得分:1)

我明白了。显然,tf.Saver不会从checkpoint恢复所有变量。我尝试立即恢复和保存,输出只有一半。

我使用tf.train.list_variables获取最新checkpoint的所有变量,然后将其转换为tf.Variable并从中创建dict。然后我将dict传递给tf.Saver,它恢复了我的所有变量。

接下来是initialize所有变量,然后修改权重。

现在它正在运作。