How to restore only certain variables with tf.train.MonitoredTrainingSession

Date: 2017-04-11 04:03:23

Tags: tensorflow

How do I tell tf.train.MonitoredTrainingSession to restore only a subset of the variables, and run an initializer for the rest?

Starting from the cifar10 tutorial.. https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10_train.py

..I created lists of the variables to restore and to initialize, and specified them using the Scaffold that I pass to MonitoredTrainingSession:

  restoration_saver = tf.train.Saver(var_list=restore_vars)
  restoration_scaffold = tf.train.Scaffold(init_op=tf.variables_initializer(init_vars),
                                           ready_op=tf.constant([]),
                                           saver=restoration_saver)

But this gives the following error:

  RuntimeError: Init operations did not make model ready for local_init. Init op: group_deps, init fn: None, error: Variables not initialized: conv2a/T, conv2b/T, [...]

..and the uninitialized variables listed in the error message are the variables in my "init_vars" list.

The exception is raised by SessionManager.prepare_session(). The source code of that method seems to indicate that the init_op is not run if the session is restored from a checkpoint. So it looks like you can either have restored variables or initialized variables, but not both.

5 Answers:

Answer 0 (score: 3)

OK so as I suspected, I got what I wanted by implementing a new RefinementSessionManager class based on the existing tf.train.SessionManager. The two classes are almost identical, except that I modified the prepare_session method to call the init_op regardless of whether the model was loaded from a checkpoint.

This lets me load a list of variables from the checkpoint and initialize the remaining variables in the init_op.

My prepare_session method is this:

  def prepare_session(self, master, init_op=None, saver=None,
                      checkpoint_dir=None, wait_for_checkpoint=False,
                      max_wait_secs=7200, config=None, init_feed_dict=None,
                      init_fn=None):

    sess, is_loaded_from_checkpoint = self._restore_checkpoint(
        master,
        saver,
        checkpoint_dir=checkpoint_dir,
        wait_for_checkpoint=wait_for_checkpoint,
        max_wait_secs=max_wait_secs,
        config=config)

    # [removed] if not is_loaded_from_checkpoint:
    # we still want to run any supplied initialization on models that
    # were loaded from checkpoint.

    if not is_loaded_from_checkpoint and init_op is None and not init_fn and self._local_init_op is None:
      raise RuntimeError("Model is not initialized and no init_op or "
                         "init_fn or local_init_op was given")
    if init_op is not None:
      sess.run(init_op, feed_dict=init_feed_dict)
    if init_fn:
      init_fn(sess)

    # [...]

Hope this helps someone else.

Answer 1 (score: 1)

To make the hint from @avital more complete: pass a Scaffold object to MonitoredTrainingSession, with a local_init_op and a ready_for_local_init_op. Like this:

model_ready_for_local_init_op = tf.report_uninitialized_variables(
    var_list=var_list)
model_init_tmp_vars = tf.variables_initializer(var_list)
scaffold = tf.train.Scaffold(saver=model_saver,
                             local_init_op=model_init_tmp_vars,
                             ready_for_local_init_op=model_ready_for_local_init_op)
with tf.train.MonitoredTrainingSession(...,
                                       scaffold=scaffold,
                                       ...) as mon_sess:
    ...

Answer 2 (score: 0)

You can solve this with the local_init_op argument, which is run after loading from the checkpoint.
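
A minimal sketch of that wiring (assuming TF 1.x and the restore_vars/init_vars lists from the question):

# Restore only restore_vars; local_init_op then initializes init_vars
# even though the session was restored from a checkpoint.
scaffold = tf.train.Scaffold(
    saver=tf.train.Saver(var_list=restore_vars),
    ready_for_local_init_op=tf.report_uninitialized_variables(restore_vars),
    local_init_op=tf.group(tf.local_variables_initializer(),
                           tf.variables_initializer(init_vars)))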

Answer 3 (score: 0)

The documentation of Scaffold lists the following arguments:

  • init_op
  • ready_op
  • local_init_op
  • ready_for_local_init_op

The init_op is called only when we are not restoring from a checkpoint:

if not is_loaded_from_checkpoint:
  if init_op is None and not init_fn and self._local_init_op is None:
    raise RuntimeError("Model is not initialized and no init_op or "
                   "init_fn or local_init_op was given")
  if init_op is not None:
    sess.run(init_op, feed_dict=init_feed_dict)
  if init_fn:
    init_fn(sess)

So the init_op cannot help here. If you can write your own SessionManager, you can follow @user550701's answer above. We can also use the local_init_op, but it is a little tricky in distributed situations.

Scaffold generates default init_op and local_init_op arguments for us (roughly sketched after the list below):

  • init_op: initializes tf.global_variables
  • local_init_op: initializes tf.local_variables
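
Roughly, those defaults correspond to the following (an approximation for intuition, not the exact library internals):

default_init_op = tf.variables_initializer(tf.global_variables())
default_local_init_op = tf.variables_initializer(tf.local_variables())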

We should initialize our own variables without breaking this default mechanism at the same time.

The single-worker case

You can create your local_init_op like this (a wiring sketch follows the snippet):

target_collection = [] # Put your target tensors here
collection = tf.local_variables() + target_collection
local_init_op = tf.variables_initializer(collection)
ready_for_local_init_op = tf.report_uninitialized_variables(collection)
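
A hedged wiring sketch for these two ops (checkpoint_dir here is an assumed variable; the Scaffold/MonitoredTrainingSession usage mirrors the earlier answers):

scaffold = tf.train.Scaffold(local_init_op=local_init_op,
                             ready_for_local_init_op=ready_for_local_init_op)
with tf.train.MonitoredTrainingSession(checkpoint_dir=checkpoint_dir,
                                       scaffold=scaffold) as sess:
    ...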

The distributed case

We should watch out for repeated initialization of target_collection, because local_init_op will be called multiple times across the workers. If the variables are local ones, it makes no difference. If they are global variables, we should make sure they are initialized only once. To solve the duplication problem, we can manipulate the collection variable: for the chief worker, it contains both the local variables and our target_collection, while for non-chief workers we only put the local variables into it, as shown below.

if is_chief:
    collection = tf.local_variables() + target_collection
else:
    collection = tf.local_variables()
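
For reference, a common way to derive is_chief in TF 1.x between-graph replication (an assumption about the cluster setup, not something given in this answer):

is_chief = (task_index == 0)  # task 0 of the 'worker' job usually acts as chief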

All in all, it is a bit tricky, but we do not have to hack into tensorflow.

Answer 4 (score: 0)

I ran into the same problem; my solution was:

checkpoint_restore_dir_for_monitered_session = None
scaffold = None
if params.restore:
    checkpoint_restore_dir_for_monitered_session = checkpoint_save_dir

    restore_exclude_name_list = params.restore_exclude_name_list
    if len(restore_exclude_name_list) != 0:
        variables_to_restore, variables_dont_restore = get_restore_var_list(restore_exclude_name_list)
        saver_for_restore = tf.train.Saver(var_list=variables_to_restore, name='saver_for_restore')
        ready_for_local_init_op = tf.report_uninitialized_variables(variables_to_restore.values())

        local_init_op = tf.group([
            tf.initializers.local_variables(),
            tf.initializers.variables(variables_dont_restore)
            ])

        scaffold = tf.train.Scaffold(saver=saver_for_restore,
                                     ready_for_local_init_op=ready_for_local_init_op,
                                     local_init_op=local_init_op)

with tf.train.MonitoredTrainingSession(
        checkpoint_dir=checkpoint_restore_dir_for_monitered_session, 
        save_checkpoint_secs=None,  # don't save ckpt
        hooks=train_hooks,
        config=config,
        scaffold=scaffold,
        summary_dir=params.log_dir) as sess:
    pass

In this snippet, get_restore_var_list produces variables_to_restore and variables_dont_restore.
saver_for_restore restores only the variables in variables_to_restore, which are afterwards checked by (and pass) ready_for_local_init_op.
Then local_init_op runs, initializing local_variables() and variables_dont_restore (perhaps with tf.variance_scaling_initializer...).
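
The answer does not show get_restore_var_list, so here is a hypothetical sketch of such a helper (the dict return type is inferred from the .values() call above; the substring-matching rule is an assumption):

def get_restore_var_list(exclude_name_list):
    # Split the global variables into those restored from the checkpoint
    # and those excluded from the restore (and therefore re-initialized).
    variables_to_restore = {}
    variables_dont_restore = []
    for var in tf.global_variables():
        if any(name in var.op.name for name in exclude_name_list):
            variables_dont_restore.append(var)
        else:
            variables_to_restore[var.op.name] = var
    return variables_to_restore, variables_dont_restore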