如何在TensorFlow中导入模型

时间:2017-05-10 08:38:38

标签: python tensorflow

我在恢复已保存的模型时遇到了困难。我正在MNIST数据集上训练CNN,根据Deep MNIST for Experts上的MNIST教程,我使用以下代码保存我的模型:

saver.save(sess, './Tensorflow_MNIST', global_step=max_steps)

这将创建以下文件:

  • Tensorflow_MNIST-1000.data-00000-的-00001
  • Tensorflow_MNIST-1000.index
  • Tensorflow_MNIST-1000.meta
  • 检查点

稍后我想加载模型并继续训练:

with tf.Session() as sess:
new_saver = tf.train.import_meta_graph('./Tensorflow_MNIST-1000.meta')
new_saver.restore(sess, './Tensorflow_MNIST-1000')

batch_xs, batch_ys = mnist.train.next_batch(50)
sess.run(train_step, feed_dict[x: batch_xs, y_batch_ys, keep_prob:0.5])

然而,这会返回错误:

NameError: name 'train_step' is not defined

因此,图表及其变量和操作似乎未正确加载。我在这里做错了什么?

3 个答案:

答案 0 :(得分:2)

保存时

saver = tf.train.Saver(...variables...)
# Remember the training_op we want to run by adding it to a collection.
tf.add_to_collection('train_step', train_step)

恢复时:

with tf.Session() as sess:
  ....

  # tf.get_collection() returns a list. get the first one
  train_step = tf.get_collection('train_step')[0]
  sess.run(train_step, ....)

如果您想重新使用该模型,我认为将sess.run(train_step...)更改为 train_step(...)应该有效

答案 1 :(得分:1)

当使用saver.save()时,TensorFlow保存计算图,该图由张量(即TensorFlow对象)组成。

它不会保存您使用的每个变量。特别是,任何 tf.Tensor的内容都不会被保存。

您可能希望拥有自己的数据结构来保存任何其他信息。

为方便起见,您可以使用JSON格式,甚至可以在python中使用pickle,但不能手动编辑。

希望有所帮助

答案 2 :(得分:1)

用""呼叫所有张量并且:import meta_graph中描述的添加0似乎可以解决问题。因此,例如,计算准确度的调用变为:

test_accuracy = sess.run("accuracy:0", feed_dict={"x:0": mnist.test.images, "y_:0": mnist.test.labels, "keep_prob:0": 1.0})