我有使用tf.train.MonitoredTrainingSession
训练CNN的代码。
当我创建新的tf.train.MonitoredTrainingSession
时,我可以将checkpoint
目录作为输入参数传递给会话,它会自动恢复它可以找到的最新保存的checkpoint
。我可以设置hooks
来训练直到某一步。例如,如果checkpoint
的步骤为150,000
,并且我想训练到200,000
,我会将last_step
放到200,000
。
只要使用checkpoint
保存了最新的tf.train.MonitoredTrainingSession
,上述过程就能完美运行。但是,如果我尝试恢复使用普通checkpoint
保存的tf.Session
,那么所有地狱都会破裂。它无法在图表中找到一些关键字。
培训完成了:
with tf.train.MonitoredTrainingSession(
checkpoint_dir=FLAGS.retrain_dir,
hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_training_steps),
tf.train.NanTensorHook(loss),
_LoggerHook()],
config=tf.ConfigProto(
log_device_placement=FLAGS.log_device_placement)) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(train_op)
如果checkpoint_dir
属性的文件夹没有检查点,则会从头开始。如果它从之前的培训课程中保存了checkpoint
,则会恢复最新的checkpoint
并继续培训。
现在,我正在恢复最新的checkpoint
并修改一些变量并保存它们:
saver = tf.train.Saver(variables_to_restore)
ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
with tf.Session() as sess:
if ckpt and ckpt.model_checkpoint_path:
# Restores from checkpoint
saver.restore(sess, ckpt.model_checkpoint_path)
print(ckpt.model_checkpoint_path)
restored_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
FLAGS.max_training_steps = int(restored_step) + FLAGS.max_training_steps
else:
print('No checkpoint file found')
return
prune_convs(sess)
saver.save(sess, FLAGS.retrain_dir+"model.ckpt-"+restored_step)
正如您所见,就在saver.save...
之前,我正在修剪网络中的所有卷积层。无需描述完成的方式和原因。关键是网络实际上已被修改。然后我将网络保存到checkpoint
。
现在,如果我在已保存的已修改网络上部署测试,则测试工作正常。但是,当我尝试在保存的tf.train.MonitoredTrainingSession
上运行checkpoint
时,它会说:
检查点中找不到密钥conv1 / weight_loss / avg
另外,我注意到checkpoint
保存的tf.Session
的大小是checkpoint
保存的tf.train.MonitoredTrainingSession
的一半
我知道我做错了,有什么建议如何使这项工作?
答案 0 :(得分:1)
我明白了。显然,tf.Saver
不会从checkpoint
恢复所有变量。我尝试立即恢复和保存,输出只有一半。
我使用tf.train.list_variables
获取最新checkpoint
的所有变量,然后将其转换为tf.Variable
并从中创建dict
。然后我将dict
传递给tf.Saver
,它恢复了我的所有变量。
接下来是initialize
所有变量,然后修改权重。
现在它正在运作。