我像这样保存会话状态:
self._saver = tf.saver()
self._saver.save(self._session, '/network', global_step=self._time)
当我稍后恢复时,我想获取我从中恢复的检查点的global_step的值。这是为了从中设置一些超参数。
执行此操作的hacky方法是运行并解析检查点目录中的文件名。但是,必须有更好的,内置的方法来做到这一点?
答案 0 :(得分:24)
一般模式是使用global_step
变量来跟踪步骤
global_step = tf.Variable(0, name='global_step', trainable=False)
train_op = optimizer.minimize(loss, global_step=global_step)
然后你可以用
保存saver.save(sess, save_path, global_step=global_step)
还原时,global_step
的值也会恢复
答案 1 :(得分:5)
这有点像黑客,但其他答案对我来说根本不起作用
string duration = (starttime.HasValue && endtime.HasValue)
? (endtime - starttime).ToString()
: "0";
更新9/2017
我不确定这是否因更新而开始工作,但以下方法似乎可以有效地使global_step更新并正确加载:
创建两个操作。一个用于保存global_step而另一个用于增加它:
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
#Extract from checkpoint filename
step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[1])
现在,在您的训练循环中,每次运行训练操作时都会运行增量操作。
global_step = tf.Variable(0, trainable=False, name='global_step')
increment_global_step = tf.assign_add(global_step,1,
name = 'increment_global_step')
如果您想在任何时候将整数步长值检索为整数,请在加载模型后使用以下命令:
sess.run([train_op,increment_global_step],feed_dict=feed_dict)
这对于创建文件名或计算当前时期的内容非常有用,而不需要第二个tensorflow变量来保存该值。例如,计算加载时的当前纪元就像:
sess.run(global_step)
答案 2 :(得分:1)
我和Lawrence Du有同样的问题,我找不到通过恢复模型获得global_step的方法。所以我将his hack应用于我正在使用的the inception v3 training code in the Tensorflow/models github repo。下面的代码还包含与pretrained_model_checkpoint_path
相关的修补程序。
如果您有更好的解决方案,或者知道我缺少什么,请发表评论!
无论如何,这段代码对我有用:
...
# When not restoring start at 0
last_step = 0
if FLAGS.pretrained_model_checkpoint_path:
# A model consists of three files, use the base name of the model in
# the checkpoint path. E.g. my-model-path/model.ckpt-291500
#
# Because we need to give the base name you can't assert (will always fail)
# assert tf.gfile.Exists(FLAGS.pretrained_model_checkpoint_path)
variables_to_restore = tf.get_collection(
slim.variables.VARIABLES_TO_RESTORE)
restorer = tf.train.Saver(variables_to_restore)
restorer.restore(sess, FLAGS.pretrained_model_checkpoint_path)
print('%s: Pre-trained model restored from %s' %
(datetime.now(), FLAGS.pretrained_model_checkpoint_path))
# HACK : global step is not restored for some unknown reason
last_step = int(os.path.basename(FLAGS.pretrained_model_checkpoint_path).split('-')[1])
# assign to global step
sess.run(global_step.assign(last_step))
...
for step in range(last_step + 1, FLAGS.max_steps):
...
答案 3 :(得分:1)
作为tensorflow变量(将在会话中评估)
global_step = tf.train.get_or_create_global_step()
# use global_step variable to calculate your hyperparameter
# this variable will be evaluated later in the session
saver = tf.train.Saver()
with tf.Session() as sess:
# restore all variables from checkpoint
saver.restore(sess, checkpoint_path)
# than init table and local variables and start training/evaluation ...
或:作为numpy整数(没有任何会话):
reader = tf.train.NewCheckpointReader(absolute_checkpoint_path)
global_step = reader.get_tensor('global_step')
至少有两种方法可以从检查点检索全局变量。作为tensorflow变量或numpy整数。如果global_step
的{{3}}方法中未提供Saver
作为参数,则无法解析文件名。对于预训练的模型,请参见答案末尾的说明。
如果您需要global_step
变量来计算某些超参数,则可以使用save
。这将返回一个tensorflow变量。由于该变量将在会话的稍后阶段进行评估,因此您只能使用tensorflow操作来计算您的超参数。因此,例如:max(global_step, 100)
将不起作用。您必须使用等效的tensorflow tf.maximum(global_step, 100)
,可以在会话的稍后进行评估。
在会话中,您可以使用tf.train.get_or_create_global_step()
global_step = tf.train.get_or_create_global_step()
# use global_step variable to calculate your hyperparameter
# this variable will be evaluated later in the session
hyper_parameter = tf.maximum(global_step, 100)
saver = tf.train.Saver()
with tf.Session() as sess:
# restore all variables from checkpoint
saver.restore(sess, checkpoint_path)
# than init table and local variables and start training/evaluation ...
# for verification you can print the global step and your hyper parameter
print(sess.run([global_step, hyper_parameter]))
如果您需要全局step变量作为标量而不启动会话,则还可以直接从检查点文件读取此变量。您只需要一个NewCheckpointReader
。由于旧版tensorflow版本中的saver.restore(sess, checkpoint_path)
,您应该将检查点文件的路径转换为绝对路径。使用阅读器,您可以将模型的所有张量作为numpy变量获取。
全局步骤变量的名称是定义为'global_step'
的常量字符串bug。
absolute_checkpoint_path = os.path.abspath(checkpoint_path)
reader = tf.train.NewCheckpointReader(absolute_checkpoint_path)
global_step = reader.get_tensor(tf.GraphKeys.GLOBAL_STEP)
注释预训练模型:在大多数在线可用的预训练模型中,全局步长重置为零。因此,这些模型可用于初始化模型参数以进行微调,而无需覆盖全局步骤。
答案 4 :(得分:0)
当前的0.10rc0版本似乎有所不同,不再有tf.saver()了。现在它是tf.train.Saver()。另外,save命令将信息添加到global_step的save_path文件名中,因此我们不能在同一个save_path上调用restore,因为那不是实际的保存文件。
我现在看到的最简单的方法是使用SessionManager以及这样的保护程序:
my_checkpoint_dir = "/tmp/checkpoint_dir"
# make a saver to use with SessionManager for restoring
saver = tf.train.Saver()
# Build an initialization operation to run below.
init = tf.initialize_all_variables()
# use a SessionManager to help with automatic variable restoration
sm = tf.train.SessionManager()
# try to find the latest checkpoint in my_checkpoint_dir, then create a session with that restored
# if no such checkpoint, then call the init_op after creating a new session
sess = sm.prepare_session("", init_op=init, saver=saver, checkpoint_dir=my_checkpoint_dir))
那就是它。现在你有一个会话要么从my_checkpoint_dir恢复(在调用之前确保该目录存在),或者如果那里没有检查点那么它会创建一个新会话并调用init_op来初始化你的变量
当你想要保存时,你只需保存到该目录中你想要的任何名称并传递global_step。这里是一个例子,我将循环变量保存在循环中作为global_step,所以它返回到那一点,如果你杀死程序并重新启动它,以便恢复检查点:
checkpoint_path = os.path.join(my_checkpoint_dir, 'model.ckpt')
saver.save(sess, checkpoint_path, global_step=step)
这会在my_checkpoint_dir中创建文件,例如" model.ckpt-1000"其中1000是传入的global_step。如果它继续运行,那么你会更像" model.ckpt-2000"。上面的SessionManager在程序重启时选取最新的一个。 checkpoint_path可以是您想要的任何文件名,只要它在checkpoint_dir中即可。 save()将创建附加了global_step的文件(如上所示)。它还创建了一个"检查点"索引文件,这是SessionManager随后找到最新保存检查点的方式。
答案 5 :(得分:0)
请注意我的全球步骤保存和恢复解决方案。
保存:
global_step = tf.Variable(0, trainable=False, name='global_step')
saver.save(sess, model_path + model_name, global_step=_global_step)
还原:
if os.path.exists(model_path):
saver.restore(sess, tf.train.latest_checkpoint(model_path))
print("Model restore finished, current globle step: %d" % global_step.eval())
答案 6 :(得分:0)
您可以使用step
变量来跟踪步骤,但是如果在代码中要将此值初始化或分配给另一个global_step
变量,则该值可能不一致。
例如,您使用以下命令定义global_step = tf.Variable(0, name='global_step', trainable=False)
:
train_op = optimizer.minimize(loss, global_step=global_step)
分配您的训练操作:
saver.save(sess, checkpoint_path, global_step=global_step)
保存在检查点中:
saver.restore(sess, checkpoint_path)
并从您的检查点恢复:
global_step
step
的值也会恢复,但是如果您将其分配给另一个变量,例如step = global_step.eval(session=sess)
,则必须执行以下操作:
step
变量global_step
包含检查点中最后保存的global_step = tf.train.get_or_create_global_step()
。
最好也将图中的global_step定义为零变量(如先前定义):
global_step
这将获得您的最后一个.prop('checked', true)
(如果存在)或创建一个(如果不存在)。
答案 7 :(得分:0)
未按预期还原变量的原因很可能是由于它是在创建tf.Saver()
对象之后创建的。
当您未显式指定tf.Saver()
或未为var_list
指定None
时,创建var_list
对象的位置很重要。对于许多程序员而言,预期的行为是调用save()
方法时会保存图形中的所有变量,但事实并非如此,也许应该这样记录。创建对象时,将保存图中所有变量的快照。
除非遇到任何性能问题,否则在决定保存进度时,最安全的方法就是创建保护对象。否则,请确保在创建所有变量之后创建保护对象。
此外,传递给global_step
的{{1}}只是用于创建文件名的计数器,与是否将其恢复为saver.save(sess, save_path, global_step=global_step)
变量无关。这是一个参数错误的IMO,因为如果您要在每个纪元末尾保存进度,则最好为该参数传递纪元号。