我的目的很简单明了:在图形被部分修改之后,如何从之前的日志检查点文件恢复未更改的变量/参数?(更好地使用MonitoredTrainingSession)
我在这里对代码进行测试: https://github.com/tensorflow/models/tree/master/research/resnet
在resnet_model.py第116-118行,原始代码(或图表)为:
with tf.variable_scope('logit'):
logits = self._fully_connected(x, self.hps.num_classes)
self.predictions = tf.nn.softmax(logits)
with tf.variable_scope('costs'):
xent = tf.nn.softmax_cross_entropy_with_logits(
logits=logits, labels=self.labels)
self.cost = tf.reduce_mean(xent, name='xent')
self.cost += self._decay()
在第一次培训后,我获得了检查点文件。 然后我将代码修改为:
with tf.variable_scope('logit_modified'):
logits_modified = self._fully_connected('fc_1',x, 48)
#self.predictions = tf.nn.softmax(logits)
with tf.variable_scope('logit_2'):
logits_2 = self._fully_connected('fc_2', logits_modified,
self.hps.num_classes)
self.predictions = tf.nn.softmax(logits_2)
with tf.variable_scope('costs'):
xent = tf.nn.softmax_cross_entropy_with_logits(
logits=logits_2, labels=self.labels)
self.cost = tf.reduce_mean(xent, name='xent')
self.cost += self._decay()
然后我尝试使用latested API tf.train.MonitoredTrainingSession来恢复第一次训练中获得的检查点。我已经尝试了多种方法来做到这一点,但它们都不起作用。
尝试1: 如果我在MonitoredTrainingSession中不使用scaffold:
with tf.train.MonitoredTrainingSession(
checkpoint_dir=FLAGS.log_root,
#scaffold=scaffold,
hooks=[logging_hook, _LearningRateSetterHook()],
chief_only_hooks=[summary_hook],
save_checkpoint_secs = 600,
# Since we provide a SummarySaverHook, we need to disable default
# SummarySaverHook. To do that we set save_summaries_steps to 0.
save_summaries_steps=None,
save_summaries_secs=None,
config=tf.ConfigProto(allow_soft_placement=True),
stop_grace_period_secs=120,
log_step_count_steps=100) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(_train_op)
错误消息是:
2017-12-29 10:33:30.699061:W tensorflow / core / framework / op_kernel.cc:1192]未找到:密钥logit_modified / fc_1 /偏差/检查点中未找到动量 ...
虽然会话似乎试图根据修改后的图形进行恢复,但不是新图形和先前检查点文件中存在的变量(换句话说,所有图层都排除了最终的2)。
尝试2: 受到使用tf.train.Supervisor的转学习代码的启发: https://github.com/kwotsin/transfer_learning_tutorial/blob/master/train_flowers.py,来自第251行。
首先我修改了resnet_model.py中的代码,添加以下行:
self.variables_to_restore = tf.contrib.framework.get_variables_to_restore(
exclude=["logit_modified", "logit_2"])
然后MonitoredTrainingSession中的脚手架变为:
saver = tf.train.Saver(variables_to_restore)
def restore_fn(sess):
return saver.restore(sess, FLAGS.log_root)
scaffold = tf.train.Scaffold(saver=saver, init_fn = restore_fn)
不幸的是,显示了以下错误消息:
RuntimeError:Init操作没有为local_init准备模型。 Init op:group_deps,init fn:at 0x7f0ec26f4320&gt ;, error:变量未初始化:logit_modified / fc_1 / DW,...
似乎最后2层未正确恢复,因此其余图层未恢复。
尝试3: 我也尝试过这里列出的方法:How to use tf.train.MonitoredTrainingSession to restore only certain variables,但它们都不起作用。
我知道还有其他方法可以恢复,例如https://github.com/tensorflow/models/blob/6fb14a790c283a922119b19632e3f7b8e5c0a729/research/inception/inception/inception_model.py中的代码,但它们是嵌套的,并且不够通用,无法轻松应用于其他模型。这就是我想使用“MonitoredTrainingSession”的原因。
那么如何使用“MonitoredTrainingSession”来恢复tensorflow中的部分检查点?
答案 0 :(得分:1)
好的,最后我弄清楚了。
在这里阅读monitored_session.py后: https://github.com/tensorflow/tensorflow/blob/4806cb0646bd21f713722bd97c0d0262c575f7e0/tensorflow/python/training/monitored_session.py,我发现关键(而且非常棘手)的一点是更改为新的空检查点目录,以便MonitoredTrainingSession不会忽略init_op或init_fn。 然后你可以使用以下代码来构建你的init_fn(为了恢复检查点)以及脚手架:
variables_to_restore = tf.contrib.framework.get_variables_to_restore(
exclude=['XXX'])
init_assign_op, init_feed_dict = tf.contrib.framework.assign_from_checkpoint(
ckpt.model_checkpoint_path, variables_to_restore)
def InitAssignFn(scaffold,sess):
sess.run(init_assign_op, init_feed_dict)
scaffold = tf.train.Scaffold(saver=tf.train.Saver(), init_fn=InitAssignFn)
请记住上面的 ckpt.model_checkpoint_path 是您的旧检查点路径,其中包含预先训练的文件。我上面提到的新的空检查点目录表示参数" checkpoint_dir"监控培训会议在这里:
with tf.train.MonitoredTrainingSession(
checkpoint_dir=FLAGS.log_root_2,...) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(_train_op)
我修改的代码的第一段派生自tf.slim中的learning.py,来自第134行: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/learning.py
加: 感谢这个Q& A的灵感,虽然解决方案有点不同: What's the recommend way of restoring only parts model in distributed tensorflow