如何从上次保存状态开始重新训练Tensorflow seq2seq模型?

时间:2016-11-27 00:21:07

标签: machine-learning nlp tensorflow

我对tensorflow完全不熟悉,我正在使用他们的seq2seq翻译示例。我查看了translate.py中的代码,并且训练是在无限循环中完成的,它会不时地将检查点保存在文件translate.ckpt中。

因此,如果我停止训练并希望稍后从上次保存的状态重新开始训练,我该怎么做?

由于

2 个答案:

答案 0 :(得分:0)

您需要从文件中恢复变量:

,而不是在会话中启动变量
saver = tf.train.Saver()
with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")
  print("Model restored.")
  # Do some work with the model

请注意,您的模型只是变量的值。要恢复它们,您需要一个具有相同变量名称的图形。并且可能需要执行操作才能计算结果。

在这里阅读更多相关信息:

https://www.tensorflow.org/versions/r0.11/how_tos/variables/index.html#restoring-variables

答案 1 :(得分:0)

我想出来,以为我应该回答它。 seq2seq示例默认执行此操作。如果您停止训练循环然后重新启动它,它将查找已保存的检查点并从上次停止的位置重新开始训练。