如何在Tensorflow中恢复检查点时获取global_step?

时间:2016-03-20 11:25:10

标签: tensorflow

我像这样保存会话状态:

self._saver = tf.saver()
self._saver.save(self._session, '/network', global_step=self._time)

当我稍后恢复时,我想获取我从中恢复的检查点的global_step的值。这是为了从中设置一些超参数。

执行此操作的hacky方法是运行并解析检查点目录中的文件名。但是,必须有更好的,内置的方法来做到这一点?

8 个答案:

答案 0 :(得分:24)

一般模式是使用global_step变量来跟踪步骤

global_step = tf.Variable(0, name='global_step', trainable=False)
train_op = optimizer.minimize(loss, global_step=global_step)

然后你可以用

保存
saver.save(sess, save_path, global_step=global_step)

还原时,global_step的值也会恢复

答案 1 :(得分:5)

这有点像黑客,但其他答案对我来说根本不起作用

string duration = (starttime.HasValue && endtime.HasValue) 
    ? (endtime - starttime).ToString() 
    : "0";

更新9/2017

我不确定这是否因更新而开始工作,但以下方法似乎可以有效地使global_step更新并正确加载:

创建两个操作。一个用于保存global_step而另一个用于增加它:

ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 

#Extract from checkpoint filename
step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1])

现在,在您的训练循环中,每次运行训练操作时都会运行增量操作。

    global_step = tf.Variable(0, trainable=False, name='global_step')
    increment_global_step = tf.assign_add(global_step,1,
                                            name = 'increment_global_step')

如果您想在任何时候将整数步长值检索为整数,请在加载模型后使用以下命令:

sess.run([train_op,increment_global_step],feed_dict=feed_dict)

这对于创建文件名或计算当前时期的内容非常有用,而不需要第二个tensorflow变量来保存该值。例如,计算加载时的当前纪元就像:

sess.run(global_step)

答案 2 :(得分:1)

我和Lawrence Du有同样的问题,我找不到通过恢复模型获得global_step的方法。所以我将his hack应用于我正在使用的the inception v3 training code in the Tensorflow/models github repo。下面的代码还包含与pretrained_model_checkpoint_path相关的修补程序。

如果您有更好的解决方案,或者知道我缺少什么,请发表评论!

无论如何,这段代码对我有用:

...

# When not restoring start at 0
last_step = 0
if FLAGS.pretrained_model_checkpoint_path:
    # A model consists of three files, use the base name of the model in
    # the checkpoint path. E.g. my-model-path/model.ckpt-291500
    #
    # Because we need to give the base name you can't assert (will always fail)
    # assert tf.gfile.Exists(FLAGS.pretrained_model_checkpoint_path)

    variables_to_restore = tf.get_collection(
        slim.variables.VARIABLES_TO_RESTORE)
    restorer = tf.train.Saver(variables_to_restore)
    restorer.restore(sess, FLAGS.pretrained_model_checkpoint_path)
    print('%s: Pre-trained model restored from %s' %
          (datetime.now(), FLAGS.pretrained_model_checkpoint_path))

    # HACK : global step is not restored for some unknown reason
    last_step = int(os.path.basename(FLAGS.pretrained_model_checkpoint_path).split('-')[1])

    # assign to global step
    sess.run(global_step.assign(last_step))

...

for step in range(last_step + 1, FLAGS.max_steps):

  ...

答案 3 :(得分:1)

TL; DR

作为tensorflow变量(将在会话中评估)

global_step = tf.train.get_or_create_global_step()
# use global_step variable to calculate your hyperparameter 
# this variable will be evaluated later in the session
saver = tf.train.Saver()
with tf.Session() as sess:
    # restore all variables from checkpoint
    saver.restore(sess, checkpoint_path)
    # than init table and local variables and start training/evaluation ...

或:作为numpy整数(没有任何会话):

reader = tf.train.NewCheckpointReader(absolute_checkpoint_path)
global_step = reader.get_tensor('global_step')


长答案

至少有两种方法可以从检查点检索全局变量。作为tensorflow变量或numpy整数。如果global_step的{​​{3}}方法中未提供Saver作为参数,则无法解析文件名。对于预训练的模型,请参见答案末尾的说明。

作为Tensorflow变量

如果您需要global_step变量来计算某些超参数,则可以使用save。这将返回一个tensorflow变量。由于该变量将在会话的稍后阶段进行评估,因此您只能使用tensorflow操作来计算您的超参数。因此,例如:max(global_step, 100)将不起作用。您必须使用等效的tensorflow tf.maximum(global_step, 100),可以在会话的稍后进行评估。

在会话中,您可以使用tf.train.get_or_create_global_step()

使用检查点来初始化全局step变量。
global_step = tf.train.get_or_create_global_step()
# use global_step variable to calculate your hyperparameter 
# this variable will be evaluated later in the session
hyper_parameter = tf.maximum(global_step, 100) 
saver = tf.train.Saver()
with tf.Session() as sess:
    # restore all variables from checkpoint
    saver.restore(sess, checkpoint_path)
    # than init table and local variables and start training/evaluation ...

    # for verification you can print the global step and your hyper parameter
    print(sess.run([global_step, hyper_parameter]))

或:作为numpy整数(无会话)

如果您需要全局step变量作为标量而不启动会话,则还可以直接从检查点文件读取此变量。您只需要一个NewCheckpointReader。由于旧版tensorflow版本中的saver.restore(sess, checkpoint_path),您应该将检查点文件的路径转换为绝对路径。使用阅读器,您可以将模型的所有张量作为numpy变量获取。 全局步骤变量的名称是定义为'global_step'的常量字符串bug

absolute_checkpoint_path = os.path.abspath(checkpoint_path)
reader = tf.train.NewCheckpointReader(absolute_checkpoint_path)
global_step = reader.get_tensor(tf.GraphKeys.GLOBAL_STEP)

注释预训练模型:在大多数在线可用的预训练模型中,全局步长重置为零。因此,这些模型可用于初始化模型参数以进行微调,而无需覆盖全局步骤。

答案 4 :(得分:0)

当前的0.10rc0版本似乎有所不同,不再有tf.saver()了。现在它是tf.train.Saver()。另外,save命令将信息添加到global_step的save_path文件名中,因此我们不能在同一个save_path上调用restore,因为那不是实际的保存文件。

我现在看到的最简单的方法是使用SessionManager以及这样的保护程序:

my_checkpoint_dir = "/tmp/checkpoint_dir"
# make a saver to use with SessionManager for restoring
saver = tf.train.Saver()
# Build an initialization operation to run below.
init = tf.initialize_all_variables()
# use a SessionManager to help with automatic variable restoration
sm = tf.train.SessionManager()
# try to find the latest checkpoint in my_checkpoint_dir, then create a session with that restored
# if no such checkpoint, then call the init_op after creating a new session
sess = sm.prepare_session("", init_op=init, saver=saver, checkpoint_dir=my_checkpoint_dir))

那就是它。现在你有一个会话要么从my_checkpoint_dir恢复(在调用之前确保该目录存在),或者如果那里没有检查点那么它会创建一个新会话并调用init_op来初始化你的变量

当你想要保存时,你只需保存到该目录中你想要的任何名称并传递global_step。这里是一个例子,我将循环变量保存在循环中作为global_step,所以它返回到那一点,如果你杀死程序并重新启动它,以便恢复检查点:

checkpoint_path = os.path.join(my_checkpoint_dir, 'model.ckpt')
saver.save(sess, checkpoint_path, global_step=step)

这会在my_checkpoint_dir中创建文件,例如" model.ckpt-1000"其中1000是传入的global_step。如果它继续运行,那么你会更像" model.ckpt-2000"。上面的SessionManager在程序重启时选取最新的一个。 checkpoint_path可以是您想要的任何文件名,只要它在checkpoint_dir中即可。 save()将创建附加了global_step的文件(如上所示)。它还创建了一个"检查点"索引文件,这是SessionManager随后找到最新保存检查点的方式。

答案 5 :(得分:0)

请注意我的全球步骤保存和恢复解决方案。

保存:

global_step = tf.Variable(0, trainable=False, name='global_step')
saver.save(sess, model_path + model_name, global_step=_global_step)

还原:

if os.path.exists(model_path):
    saver.restore(sess, tf.train.latest_checkpoint(model_path))
    print("Model restore finished, current globle step: %d" % global_step.eval())

答案 6 :(得分:0)

您可以使用step变量来跟踪步骤,但是如果在代码中要将此值初始化或分配给另一个global_step变量,则该值可能不一致。

例如,您使用以下命令定义global_step = tf.Variable(0, name='global_step', trainable=False)

train_op = optimizer.minimize(loss, global_step=global_step)

分配您的训练操作:

saver.save(sess, checkpoint_path, global_step=global_step)

保存在检查点中:

saver.restore(sess, checkpoint_path) 

并从您的检查点恢复:

global_step

step的值也会恢复,但是如果您将其分配给另一个变量,例如step = global_step.eval(session=sess) ,则必须执行以下操作:

step

变量global_step包含检查点中最后保存的global_step = tf.train.get_or_create_global_step()

最好也将图中的global_step定义为零变量(如先前定义):

global_step

这将获得您的最后一个.prop('checked', true)(如果存在)或创建一个(如果不存在)。

答案 7 :(得分:0)

未按预期还原变量的原因很可能是由于它是在创建tf.Saver()对象之后创建的。

当您未显式指定tf.Saver()或未为var_list指定None时,创建var_list对象的位置很重要。对于许多程序员而言,预期的行为是调用save()方法时会保存图形中的所有变量,但事实并非如此,也许应该这样记录。创建对象时,将保存图中所有变量的快照。

除非遇到任何性能问题,否则在决定保存进度时,最安全的方法就是创建保护对象。否则,请确保在创建所有变量之后创建保护对象。

此外,传递给global_step的{​​{1}}只是用于创建文件名的计数器,与是否将其恢复为saver.save(sess, save_path, global_step=global_step)变量无关。这是一个参数错误的IMO,因为如果您要在每个纪元末尾保存进度,则最好为该参数传递纪元号。