我在恢复已保存的模型时遇到了困难。我正在MNIST数据集上训练CNN,根据Deep MNIST for Experts上的MNIST教程,我使用以下代码保存我的模型:
saver.save(sess, './Tensorflow_MNIST', global_step=max_steps)
这将创建以下文件:
稍后我想加载模型并继续训练:
with tf.Session() as sess:
new_saver = tf.train.import_meta_graph('./Tensorflow_MNIST-1000.meta')
new_saver.restore(sess, './Tensorflow_MNIST-1000')
batch_xs, batch_ys = mnist.train.next_batch(50)
sess.run(train_step, feed_dict[x: batch_xs, y_batch_ys, keep_prob:0.5])
然而,这会返回错误:
NameError: name 'train_step' is not defined
因此,图表及其变量和操作似乎未正确加载。我在这里做错了什么?
答案 0 :(得分:2)
:
saver = tf.train.Saver(...variables...)
# Remember the training_op we want to run by adding it to a collection.
tf.add_to_collection('train_step', train_step)
恢复时:
with tf.Session() as sess:
....
# tf.get_collection() returns a list. get the first one
train_step = tf.get_collection('train_step')[0]
sess.run(train_step, ....)
如果您想重新使用该模型,我认为将sess.run(train_step...)
更改为
train_step(...)
应该有效
答案 1 :(得分:1)
当使用saver.save()
时,TensorFlow保存计算图,该图由张量(即TensorFlow对象)组成。
它不会保存您使用的每个变量。特别是,任何不 tf.Tensor
的内容都不会被保存。
您可能希望拥有自己的数据结构来保存任何其他信息。
为方便起见,您可以使用JSON格式,甚至可以在python中使用pickle
,但不能手动编辑。
希望有所帮助
答案 2 :(得分:1)
用""呼叫所有张量并且:import meta_graph中描述的添加0似乎可以解决问题。因此,例如,计算准确度的调用变为:
test_accuracy = sess.run("accuracy:0", feed_dict={"x:0": mnist.test.images, "y_:0": mnist.test.labels, "keep_prob:0": 1.0})