在张量流中,当图形被修改时,如何使用“MonitoredTrainingSession”仅恢复部分检查点?

时间:2017-12-29 05:19:48

标签: tensorflow

我的目的很简单明了:在图形被部分修改之后,如何从之前的日志检查点文件恢复未更改的变量/参数?(更好地使用MonitoredTrainingSession)

我在这里对代码进行测试: https://github.com/tensorflow/models/tree/master/research/resnet

在resnet_model.py第116-118行,原始代码(或图表)为:

with tf.variable_scope('logit'):
    logits = self._fully_connected(x, self.hps.num_classes)
    self.predictions = tf.nn.softmax(logits)
with tf.variable_scope('costs'):
    xent = tf.nn.softmax_cross_entropy_with_logits(
    logits=logits, labels=self.labels)
    self.cost = tf.reduce_mean(xent, name='xent')
    self.cost += self._decay()

在第一次培训后,我获得了检查点文件。 然后我将代码修改为:

with tf.variable_scope('logit_modified'):
    logits_modified = self._fully_connected('fc_1',x, 48)
    #self.predictions = tf.nn.softmax(logits)    
with tf.variable_scope('logit_2'):
    logits_2 = self._fully_connected('fc_2', logits_modified, 
    self.hps.num_classes)
    self.predictions = tf.nn.softmax(logits_2)
with tf.variable_scope('costs'):
    xent = tf.nn.softmax_cross_entropy_with_logits(
    logits=logits_2, labels=self.labels)
    self.cost = tf.reduce_mean(xent, name='xent')
    self.cost += self._decay()

然后我尝试使用latested API tf.train.MonitoredTrainingSession来恢复第一次训练中获得的检查点。我已经尝试了多种方法来做到这一点,但它们都不起作用。

尝试1: 如果我在MonitoredTrainingSession中不使用scaffold:

with tf.train.MonitoredTrainingSession(
    checkpoint_dir=FLAGS.log_root,
    #scaffold=scaffold,
    hooks=[logging_hook, _LearningRateSetterHook()],
    chief_only_hooks=[summary_hook],
    save_checkpoint_secs = 600,
    # Since we provide a SummarySaverHook, we need to disable default
    # SummarySaverHook. To do that we set save_summaries_steps to 0.
    save_summaries_steps=None,
    save_summaries_secs=None,
    config=tf.ConfigProto(allow_soft_placement=True),
    stop_grace_period_secs=120,
    log_step_count_steps=100) as mon_sess:
while not mon_sess.should_stop():
    mon_sess.run(_train_op)

错误消息是:

  

2017-12-29 10:33:30.699061:W tensorflow / core / framework / op_kernel.cc:1192]未找到:密钥logit_modified / fc_1 /偏差/检查点中未找到动量   ...

虽然会话似乎试图根据修改后的图形进行恢复,但不是新图形和先前检查点文件中存在的变量(换句话说,所有图层都排除了最终的2)。

尝试2: 受到使用tf.train.Supervisor的转学习代码的启发: https://github.com/kwotsin/transfer_learning_tutorial/blob/master/train_flowers.py,来自第251行。

首先我修改了resnet_model.py中的代码,添加以下行:

self.variables_to_restore = tf.contrib.framework.get_variables_to_restore(
exclude=["logit_modified", "logit_2"])

然后MonitoredTrainingSession中的脚手架变为:

saver = tf.train.Saver(variables_to_restore)
def restore_fn(sess):
    return saver.restore(sess, FLAGS.log_root)
scaffold = tf.train.Scaffold(saver=saver, init_fn = restore_fn)

不幸的是,显示了以下错误消息:

  

RuntimeError:Init操作没有为local_init准备模型。 Init op:group_deps,init fn:at 0x7f0ec26f4320&gt ;, error:变量未初始化:logit_modified / fc_1 / DW,...

似乎最后2层未正确恢复,因此其余图层未恢复。

尝试3: 我也尝试过这里列出的方法:How to use tf.train.MonitoredTrainingSession to restore only certain variables,但它们都不起作用。

我知道还有其他方法可以恢复,例如https://github.com/tensorflow/models/blob/6fb14a790c283a922119b19632e3f7b8e5c0a729/research/inception/inception/inception_model.py中的代码,但它们是嵌套的,并且不够通用,无法轻松应用于其他模型。这就是我想使用“MonitoredTrainingSession”的原因。

那么如何使用“MonitoredTrainingSession”来恢复tensorflow中的部分检查点?

1 个答案:

答案 0 :(得分:1)

好的,最后我弄清楚了。

在这里阅读monitored_session.py后: https://github.com/tensorflow/tensorflow/blob/4806cb0646bd21f713722bd97c0d0262c575f7e0/tensorflow/python/training/monitored_session.py,我发现关键(而且非常棘手)的一点是更改为新的空检查点目录,以便MonitoredTrainingSession不会忽略init_op或init_fn。 然后你可以使用以下代码来构建你的init_fn(为了恢复检查点)以及脚手架:

variables_to_restore = tf.contrib.framework.get_variables_to_restore(
    exclude=['XXX'])    
init_assign_op, init_feed_dict = tf.contrib.framework.assign_from_checkpoint(
    ckpt.model_checkpoint_path, variables_to_restore)
def InitAssignFn(scaffold,sess):
    sess.run(init_assign_op, init_feed_dict)

scaffold = tf.train.Scaffold(saver=tf.train.Saver(), init_fn=InitAssignFn)

请记住上面的 ckpt.model_checkpoint_path 是您的旧检查点路径,其中包含预先训练的文件。我上面提到的新的空检查点目录表示参数" checkpoint_dir"监控培训会议在这里:

with tf.train.MonitoredTrainingSession(
    checkpoint_dir=FLAGS.log_root_2,...) as mon_sess:
while not mon_sess.should_stop():
    mon_sess.run(_train_op)

我修改的代码的第一段派生自tf.slim中的learning.py,来自第134行: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/learning.py

加: 感谢这个Q& A的灵感,虽然解决方案有点不同: What's the recommend way of restoring only parts model in distributed tensorflow