背景是我保存了训练有素的模型,现在我需要将其恢复到新模型。
问题在于新模型与以前的模型几乎没有什么不同(最后增加了一些新层)。所以我用
var_list = tf.contrib.framework.get_variables_to_restore(exclude=new_variables_name)
saver = tf.train.Saver(var_list)
恢复训练后的变量并避免新层中的变量。
但是,当我通过MonitoredTrainingSession将其还原并保存为:
with tf.train.MonitoredTrainingSession(checkpoint_dir=FLAGS.train_dir,
hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
tf.train.NanTensorHook(loss)],
config=tf.ConfigProto(
log_device_placement=FLAGS.log_device_placement)) as mon_sess:
saver.restore(mon_sess, ".../model.ckpt-100000")
while not mon_sess.should_stop():
mon_sess.run(train_op)
可以发现新变量根本没有保存。虽然我知道这是因为Saver对象在创建时没有收到这些变量的名称,但是我还是无法弄清楚如何部分还原模型以及完全保存模型,因为我对Tensorflow并不陌生。
感谢您的回复!