Tensorflow使用训练模型的权重来恢复训练

时间:2018-02-17 05:19:11

标签: python tensorflow

我想在张量流中训练模型,然后导出训练的权重(和偏差),并使用它们在具有不同较小学习率的新图中恢复训练。我不需要保存或使用先前训练模型的架构(我将再次定义模型架构)。我只需要将新权重初始化为等于训练权重。有一种直截了当的方法吗?感谢。

在尝试恢复和重用权重时,我使用了以下代码,但它没有成功:

tf.reset_default_graph()
sess = tf.Session()
new_saver = tf.train.import_meta_graph("model20256.meta") 
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
graph = tf.get_default_graph()

#Import trained weights as tensors
W1__ = graph.get_tensor_by_name('W1:0')
b1__ = graph.get_tensor_by_name('b1:0')
W2__ = graph.get_tensor_by_name('W2:0')
b2__ = graph.get_tensor_by_name('b2:0')

#Create new weight variables initialized by the trained weights    
W1_=tf.Variable(W1__)
W2_=tf.Variable(W2__)
b1_=tf.Variable(b1__)
b2_=tf.Variable(b2__)


sess_ = tf.Session()
sess_.run(tf.global_variables_initializer())

#Define a new graph
g = tf.Graph()
with g.as_default():
    #Copy variables to the new graph
    W1=tf.contrib.copy_graph.copy_variable_to_graph(W1_,g)
    W2=tf.contrib.copy_graph.copy_variable_to_graph(W2_,g)
    b1=tf.contrib.copy_graph.copy_variable_to_graph(b1_,g)
    b2=tf.contrib.copy_graph.copy_variable_to_graph(b2_,g)

1 个答案:

答案 0 :(得分:0)

在会话中,将sess_.run(tf.global_variables_initializer())替换为new_saver.restore(sess, tf.train.latest_checkpoint('./'))