如何从张量流中保存的检查点恢复特定范围的变量?

时间:2017-03-02 03:35:11

标签: tensorflow neural-network deep-learning

import tensorflow as tf
saver = tf.train.Saver() 
saver.restore(...)

但是saver.restore只有恢复整个图表的选项。我想只恢复特定范围内的变量。

提前致谢!

1 个答案:

答案 0 :(得分:5)

假设您在范围InceptionV1中拥有Google的InceptionNet模型,并且除了要重新训练的范围InceptionRetrained中的最后一个图层外,您希望加载它。

假设您已经开始重新训练最后一层并且您通过saver2.save(session, 'last_layer.ckpt')创建了 last_layer.ckpt 文件,以下是如何从两个检查点恢复网络。

saver1 = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='InceptionV1'))
saver1.restore(session, 'inception_model_from_google.ckpt')

saver2 = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='InceptionRetrained'))
saver2.restore(session, 'last_layer.ckpt')

如果您只重新训练最后一层,请不要忘记通过使用var_list参数调用优化器来禁用网络上的渐变传播(节省时间)。

tf.train.Optimizer(0.0001).minimize(
            loss, var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Inceptionretrained'))