Tensorflow:保存和恢复模型参数

时间:2016-04-22 23:06:35

标签: tensorflow

我是TensorFlow的初学者,目前正在培训CNN。

我正在使用 Saver 来保存模型使用的参数,但我担心这是否会存储模型使用的所有变量,并且足以将值恢复为 重新运行 程序,以便在训练有素的网络上执行分类/测试。

让我们看一下TensorFlow给出的着名的MNIST示例。

在这个例子中,我们有一堆Convolutional块,所有这些块都有权重,偏差变量在程序运行时被初始化。

W_conv1 = init_weight([5,5,1,32])
b_conv1 = init_bias([32])

处理完多个图层后,我们创建一个会话,并初始化添加到图表中的所有变量。

sess = tf.Session()
sess.run(tf.initialize_all_variables())
saver = tf.train.Saver()

在这里,是否有可能对saver.save代码进行注释,并在训练后用saver.restore(sess,file_path)替换它,以便将权重,偏差等参数恢复回图形?这是怎么回事?

for i in range(1000):
 ...

  if i%500 == 0:
    saver.save(sess,"model%d.cpkt"%(i))

我目前正在接受大型数据集培训,因此终止并重新启动培训是浪费时间和资源,因此我请求有人在开始培训之前澄清。

2 个答案:

答案 0 :(得分:4)

如果您只想保存最终结果一次,可以执行以下操作:

with tf.Session() as sess:
  for i in range(1000):
    ...


  path = saver.save(sess, "model.ckpt") # out of the loop
  print "Saved:", path

在其他程序中,您可以使用saver.save返回的路径加载模型以进行预测。您可以在https://github.com/sugyan/tensorflow-mnist看到一些示例。

答案 1 :(得分:1)

根据here和Sung Kim解决方案中的解释,我为这个问题编写了一个非常简单的模型。基本上,您需要从同一个类创建一个对象,并从保护程序中恢复其变量。您可以找到此解决方案的示例here

相关问题