How do I tell tf.train.MonitoredTrainingSession to restore only some variables and run initialization for the rest?

Starting from the cifar10 tutorial.. https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10_train.py
..I created lists of the variables to restore and the variables to initialize, and specified them with a Scaffold that I pass to MonitoredTrainingSession:
restoration_saver = Saver(var_list=restore_vars)
restoration_scaffold = Scaffold(init_op=variables_initializer(init_vars),
                                ready_op=constant([]),
                                saver=restoration_saver)
But this fails with the following error:

RuntimeError: Init operations did not make model ready for local_init. Init op: group_deps, init fn: None, error: Variables not initialized: conv2a/T, conv2b/T, [...]

..and the uninitialized variables listed in the error message are exactly the variables in my "init_vars" list.

The exception is raised by SessionManager.prepare_session(). The source code of that method seems to indicate that the init_op is not run when the session is restored from a checkpoint, so it looks like you can either restore variables or initialize variables, but not both.
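The either/or behavior can be illustrated with a toy pure-Python model of prepare_session's control flow (a sketch for illustration only, not the actual TensorFlow source; the names and init_value are made up):

```python
def prepare_session_model(checkpoint, variables, init_vars, init_value=0):
    """Toy model of SessionManager.prepare_session: if a checkpoint
    exists, restore from it and skip the init path entirely."""
    state = {}
    if checkpoint is not None:
        state.update(checkpoint)  # restore path: init_op never runs
    else:
        state.update({v: init_value for v in init_vars})  # init path only
    # the "ready" check: anything still missing triggers the RuntimeError
    missing = [v for v in variables if v not in state]
    return state, missing

# Restoring a partial checkpoint leaves the init_vars uninitialized:
_, missing = prepare_session_model(
    checkpoint={'conv1/W': 1.0},
    variables=['conv1/W', 'conv2a/T'],
    init_vars=['conv2a/T'])
print(missing)  # ['conv2a/T'] -> the "Variables not initialized" error
```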
Answer 0 (score: 3)
OK, as I suspected, I got what I wanted by implementing a new RefinementSessionManager class based on the existing tf.train.SessionManager. The two classes are almost identical, except that I modified the prepare_session method to call the init_op regardless of whether the model was loaded from a checkpoint.

This lets me load a list of variables from the checkpoint and initialize the remaining variables in the init_op.

My prepare_session method is:
def prepare_session(self, master, init_op=None, saver=None,
                    checkpoint_dir=None, wait_for_checkpoint=False,
                    max_wait_secs=7200, config=None, init_feed_dict=None,
                    init_fn=None):
  sess, is_loaded_from_checkpoint = self._restore_checkpoint(
      master,
      saver,
      checkpoint_dir=checkpoint_dir,
      wait_for_checkpoint=wait_for_checkpoint,
      max_wait_secs=max_wait_secs,
      config=config)

  # [removed] if not is_loaded_from_checkpoint:
  # we still want to run any supplied initialization on models that
  # were loaded from checkpoint.

  if not is_loaded_from_checkpoint and init_op is None and not init_fn and self._local_init_op is None:
    raise RuntimeError("Model is not initialized and no init_op or "
                       "init_fn or local_init_op was given")

  if init_op is not None:
    sess.run(init_op, feed_dict=init_feed_dict)

  if init_fn:
    init_fn(sess)

  # [...]
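The effect of the change can be seen in a toy pure-Python model (an illustration only, not TensorFlow code; the names and init_value are made up): initialization now runs even when a checkpoint was loaded, so variables missing from the checkpoint still receive values:

```python
def prepare_session_refined(checkpoint, init_vars, init_value=0):
    """Toy model of the modified prepare_session: restore from the
    checkpoint if present, then ALWAYS run the supplied initialization
    for any variables it did not cover."""
    state = dict(checkpoint) if checkpoint else {}
    for v in init_vars:
        state.setdefault(v, init_value)  # init runs on the restore path too
    return state

print(prepare_session_refined({'conv1/W': 1.0}, ['conv2a/T']))
# {'conv1/W': 1.0, 'conv2a/T': 0}
```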
Hope this helps somebody else.
Answer 1 (score: 1)
To complete the hint from @avital: pass a scaffold object into MonitoredTrainingSession with a local_init_op and a ready_for_local_init_op. Like this:
model_ready_for_local_init_op = tf.report_uninitialized_variables(
    var_list=var_list)
model_init_tmp_vars = tf.variables_initializer(var_list)
scaffold = tf.train.Scaffold(saver=model_saver,
                             local_init_op=model_init_tmp_vars,
                             ready_for_local_init_op=model_ready_for_local_init_op)
with tf.train.MonitoredTrainingSession(...,
                                       scaffold=scaffold,
                                       ...) as mon_sess:
  ...
Answer 2 (score: 0)

You can solve this with the local_init_op argument, which is run after loading from a checkpoint.
Answer 3 (score: 0)

The source code behind Scaffold contains the following; init_op is called only when we are not restoring from a checkpoint:
if not is_loaded_from_checkpoint:
  if init_op is None and not init_fn and self._local_init_op is None:
    raise RuntimeError("Model is not initialized and no init_op or "
                       "init_fn or local_init_op was given")
  if init_op is not None:
    sess.run(init_op, feed_dict=init_feed_dict)
  if init_fn:
    init_fn(sess)
So init_op will not help here. If you can write a new SessionManager, you can follow @user550701's approach. We can also use local_init_op, but it is a little trickier in the distributed case.

Scaffold generates a default init_op and local_init_op for us: the default init_op initializes tf.global_variables, and the default local_init_op initializes tf.local_variables. We should initialize our own variables without breaking that default mechanism.

You can create the local_init_op like this:
target_collection = [] # Put your target tensors here
collection = tf.local_variables() + target_collection
local_init_op = tf.variables_initializer(collection)
ready_for_local_init_op = tf.report_uninitialized_variables(collection)
Watch out for duplicate initialization of target_collection, because local_init_op will be called multiple times on multiple workers. If the variables are local variables, it makes no difference; but if they are global variables, you should make sure they are initialized only once. To solve the duplication problem, manipulate the collection variable: for the chief worker, it contains both the local variables and our target_collection; for non-chief workers, put only the local variables into it.
if is_chief:
  collection = tf.local_variables() + target_collection
else:
  collection = tf.local_variables()
All in all, it is a little tricky, but we do not have to hack into tensorflow.
Answer 4 (score: 0)

I ran into the same problem; my solution is:
checkpoint_restore_dir_for_monitored_session = None
scaffold = None
if params.restore:
    checkpoint_restore_dir_for_monitored_session = checkpoint_save_dir
    restore_exclude_name_list = params.restore_exclude_name_list
    if len(restore_exclude_name_list) != 0:
        variables_to_restore, variables_dont_restore = get_restore_var_list(restore_exclude_name_list)
        saver_for_restore = tf.train.Saver(var_list=variables_to_restore, name='saver_for_restore')
        ready_for_local_init_op = tf.report_uninitialized_variables(variables_to_restore.values())
        local_init_op = tf.group([
            tf.initializers.local_variables(),
            tf.initializers.variables(variables_dont_restore)
        ])
        scaffold = tf.train.Scaffold(saver=saver_for_restore,
                                     ready_for_local_init_op=ready_for_local_init_op,
                                     local_init_op=local_init_op)

with tf.train.MonitoredTrainingSession(
        checkpoint_dir=checkpoint_restore_dir_for_monitored_session,
        save_checkpoint_secs=None,  # don't save ckpt
        hooks=train_hooks,
        config=config,
        scaffold=scaffold,
        summary_dir=params.log_dir) as sess:
    pass
In this snippet, get_restore_var_list yields variables_to_restore and variables_dont_restore. saver_for_restore restores only the variables in variables_to_restore, which are then checked by ready_for_local_init_op. After that, local_init_op runs and initializes local_variables() and variables_dont_restore (with, say, tf.variance_scaling_initializer...).
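get_restore_var_list itself is not shown in the answer. A minimal sketch of what it might do (an assumption inferred from how its results are used; the answer's version takes only the exclude list and presumably reads tf.global_variables() internally, while this sketch takes the variable list explicitly so it can run standalone) partitions variables by matching names against the exclude list:

```python
class FakeVar(object):
    """Stand-in for a tf.Variable; only the .name attribute matters here."""
    def __init__(self, name):
        self.name = name

def get_restore_var_list(all_vars, restore_exclude_name_list):
    """Split variables into a {checkpoint_key: var} dict to restore and
    a list to re-initialize, based on name substrings to exclude."""
    variables_to_restore, variables_dont_restore = {}, []
    for v in all_vars:
        if any(excl in v.name for excl in restore_exclude_name_list):
            variables_dont_restore.append(v)
        else:
            # strip the ':0' output suffix so keys match checkpoint names
            variables_to_restore[v.name.split(':')[0]] = v
    return variables_to_restore, variables_dont_restore

all_vars = [FakeVar('conv1/W:0'), FakeVar('fc8/W:0')]
to_restore, dont_restore = get_restore_var_list(all_vars, ['fc8'])
print(sorted(to_restore))              # ['conv1/W']
print([v.name for v in dont_restore])  # ['fc8/W:0']
```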