How to retrieve the last global_step after saver.restore

Date: 2016-04-14 09:39:42

Tags: tensorflow

We can save a checkpoint with:

saver = tf.train.Saver()
saver.save(sess, FLAGS.train_dir, global_step=step)

Then, later, I can restore all the variables:

saver.restore(sess, FLAGS.train_dir)

I would like to get the global_step that was current when I called saver.save, so that I can continue training from that last global_step.

Is there any way to get it? It seems that CheckpointState does not contain this information:

message CheckpointState {
  // Path to the most-recent model checkpoint.
  string model_checkpoint_path = 1;

  // Paths to all not-yet-deleted model checkpoints, sorted from oldest to
  // newest.
  // Note that the value of model_checkpoint_path should be the last item in
  // this list.
  repeated string all_model_checkpoint_paths = 2;
}
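Although CheckpointState records only paths, saver.save(sess, path, global_step=step) appends the step to the checkpoint filename as a -&lt;step&gt; suffix, so the number can in principle be recovered by parsing the path. A minimal sketch in plain Python (the path below is a hypothetical example, not from the question):

```python
# A checkpoint written by saver.save(sess, '/tmp/train_dir/model.ckpt',
# global_step=4000) would end in '-4000'; parse that suffix back out.
ckpt_path = '/tmp/train_dir/model.ckpt-4000'  # hypothetical example path
last_global_step = int(ckpt_path.split('/')[-1].split('-')[-1])
print(last_global_step)  # → 4000
```

This relies only on the filename convention, so it works without adding any new variable to the graph, but it breaks if you pass a custom checkpoint name that itself contains a dash after the step.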

As in Tensorflow get the global_step when restoring checkpoints, I could introduce a new TF variable, but it would be best if I could do this without adding one. Is there any way?

2 answers:

Answer 0 (score: 1)

I think a simple sess.run(global_step) should return the value.

## Create and save global_step
global_step = tf.Variable(0, trainable=False)
train_step = tf.train.AdamOptimizer(...).minimize(..., global_step=global_step, ...) 
...
saver = tf.train.Saver() # var_list is None: defaults to the list of all saveable objects.

## Restore global_step 
sess.run(tf.global_variables_initializer())
...
ckpt = tf.train.get_checkpoint_state(FilePath_checkpoints)
if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)
    last_global_step = sess.run(global_step)

Answer 1 (score: 0)

You can do it like this:

with tf.Session() as sess:
    predict_top_5 = tf.nn.top_k(scores, k=5)
    label_top_5 = tf.nn.top_k(input_y, k=5)
    ckpt = tf.train.get_checkpoint_state('models')
    if ckpt and ckpt.model_checkpoint_path:
        saver.restore(sess, ckpt.model_checkpoint_path)
        global_step = int(ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1])