import tensorflow as tf
saver = tf.train.Saver()
saver.restore(...)
但是saver.restore只有恢复整个图表的选项。我想只恢复特定范围内的变量。
提前致谢!
答案 0 :(得分:5)
假设您在范围InceptionV1
中拥有Google的InceptionNet模型,并且除了要重新训练的范围InceptionRetrained
中的最后一个图层外,您希望加载它。
假设您已经开始重新训练最后一层并且您通过saver2.save(session, 'last_layer.ckpt')
创建了 last_layer.ckpt 文件,以下是如何从两个检查点恢复网络。
saver1 = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='InceptionV1'))
saver1.restore(session, 'inception_model_from_google.ckpt')
saver2 = tf.train.Saver(var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='InceptionRetrained'))
saver2.restore(session, 'last_layer.ckpt')
如果您只重新训练最后一层,请不要忘记通过使用var_list
参数调用优化器来禁用网络上的渐变传播(节省时间)。
tf.train.Optimizer(0.0001).minimize(
loss, var_list=tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Inceptionretrained'))