如何恢复Tensorflow模型

时间:2017-03-28 09:02:24

标签: python tensorflow

我正在学习使用Tensorflow,我编写了这个从mnist db学习的Python脚本,保存模型并对图像进行预测:

X = tf.placeholder(tf.float32, [None, 28, 28, 1])
W = tf.Variable(tf.zeros([784, 10], name="W"))
b = tf.Variable(tf.zeros([10]), name="b")
Y = tf.nn.softmax(tf.matmul(tf.reshape(X, [-1, 784]), W) + b)
# ...
init = tf.global_variables_initializer()

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(init)

    # ... learning loop

    saver.save(sess, "/tmp/my-model")

    # Make a prediction with an image
    im = numpy.asarray(Image.open("digit.png")) / 255
    im = im[numpy.newaxis, :, :, numpy.newaxis]
    dict = {X: im}
    print("Prediction: ", numpy.array(sess.run(Y, dict)).argmax())

预测是正确的,但我无法恢复已保存的模型以便重复使用。 我写了另一个试图恢复模型并做出相同预测的脚本:

X = tf.placeholder(tf.float32, [None, 28, 28, 1])
W = tf.Variable(tf.zeros([784, 10]), name="W")
b = tf.Variable(tf.ones([10]) / 10, name="b")
Y = tf.nn.softmax(tf.matmul(tf.reshape(X, [-1, 784]), W) + b)

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    saver = tf.train.import_meta_graph('/tmp/my-model.meta')
    saver.restore(sess, tf.train.latest_checkpoint('/tmp/'))

    # Make a prediction with an image
    im = numpy.asarray(Image.open("digit.png")) / 255
    im = im[numpy.newaxis, :, :, numpy.newaxis]
    dict = {X: im}
    print("Prediction: ", numpy.array(sess.run(Y, dict)).argmax())

但预测错了。 如何恢复变量并进行预测? 感谢

1 个答案:

答案 0 :(得分:1)

测试时,请注释此行

# saver = tf.train.import_meta_graph('/tmp/my-model.meta')

将解决您的问题。

import_meta_graph将创建一个新的图表/模型,保存在' .meta'文件和新模型将与您手动创建的模型共存。 saver已分配给新模型,因此saver.restore会将训练后的权重恢复为新模型,但sess会使用您手动创建的模型运行。