如何正确保存还原和修改的模型?

时间:2019-05-20 12:38:01

标签: tensorflow

背景是我保存了训练有素的模型,现在我需要将其恢复到新模型。

问题在于新模型与以前的模型几乎没有什么不同(最后增加了一些新层)。所以我用

var_list = tf.contrib.framework.get_variables_to_restore(exclude=new_variables_name)
saver = tf.train.Saver(var_list)

恢复训练后的变量并避免新层中的变量。

但是,当我通过MonitoredTrainingSession将其还原并保存为:

with tf.train.MonitoredTrainingSession(checkpoint_dir=FLAGS.train_dir,
    hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
           tf.train.NanTensorHook(loss)],
    config=tf.ConfigProto(
        log_device_placement=FLAGS.log_device_placement)) as mon_sess:
    saver.restore(mon_sess, ".../model.ckpt-100000")
    while not mon_sess.should_stop():
      mon_sess.run(train_op)

可以发现新变量根本没有保存。虽然我知道这是因为Saver对象在创建时没有收到这些变量的名称,但是我还是无法弄清楚如何部分还原模型以及完全保存模型,因为我对Tensorflow并不陌生。

感谢您的回复!

0 个答案:

没有答案