在TensorFlow的低级API中,是否可以使用优化器保存图形并在另一个文件中继续训练?

时间:2018-09-01 21:53:11

标签: python tensorflow

我创建了一个文件,在其中创建模型并使用tf.train.GradientDescentOptimizer(learning_rate=0.0001).minimize(cost, name='optimizer')和更多代码开始训练过程。

我可以保存该模型,然后在另一个文件中继续训练而不必重新创建模型吗?

我想做类似的事情:

  • 在新文件中,加载模型
  • 装有已加载的模型火车。
  • 也许在某个时间点进行推断。

修改

我的直觉告诉我这不可能。这就是我要做的:

  • 使用tf.train.Saver保存模型
  • 在另一个地方,使用tf.train.Saver加载模型
  • 创建一个新的优化器以优化模型中的成本,然后再次训练。

1 个答案:

答案 0 :(得分:1)

是的,那完全有可能。 Full TutorialDocumentation

保存:

Tensorflow变量仅在会话内有效。因此,您必须通过在saver对象上调用save方法将模型保存在会话中。

import tensorflow as tf
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my_test_model')

要在1000次迭代后保存模型,请通过传递步数来调用save:

saver.save(sess, 'my_test_model',global_step=1000)

要使用预先训练的模型进行微调,请执行以下操作:

with tf.Session() as sess:    
  saver = tf.train.import_meta_graph('my-model-1000.meta')
  saver.restore(sess,tf.train.latest_checkpoint('./'))
  print(sess.run('w1:0'))
  ##Model has been restored. Above statement will print the saved value of w1.

要通过添加更多层然后对其进行训练来添加更多操作,

sess=tf.Session()    
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))
# Now, let's access and create placeholders variables and
# create feed-dict to feed new data

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}

#Now, access the op that you want to run. 
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

#Add more to the current graph
add_on_op = tf.multiply(op_to_restore,2)

print sess.run(add_on_op,feed_dict)
#This will print 120.