保存TensorFlow模型

时间:2018-04-16 20:09:56

标签: python tensorflow mnist

我已经成功保存了一个简单的MNIST,成为下一个代码:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
import tensorflow as tf

sess = tf.InteractiveSession()
tf_save_file = './mnist-to-save-saved'
x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
saver = tf.train.Saver()

sess.run(tf.global_variables_initializer())

y = tf.matmul(x, W) + b
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels = y_, logits = y))

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
saver.save(sess, tf_save_file)

for _ in range(1000):
    batch = mnist.train.next_batch(100)
    train_step.run(feed_dict={x: batch[0], y_: batch[1]})

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
saver.save(sess, tf_save_file, global_step=1000)

print(accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

然后,下一个文件 生成

checkpoint
mnist-to-save-saved-1000.data-00000-of-00001
mnist-to-save-saved-1000.index
mnist-to-save-saved-1000.meta
mnist-to-save-saved.data-00000-of-00001
mnist-to-save-saved.index
mnist-to-save-saved.meta

现在,为了在生产中使用它(例如,传递一个数字图像),我希望能够通过传递 执行训练模型 任何数字图像进行预测(我的意思是,不是部署服务器而是进行预测" 本地",在同一目录"修复"编号图像,因此使用模型就像运行可执行文件一样。)

但是,考虑到我的代码的(中 - 低?)API级别,我对最简单正确的下一步(如果使用Estimator进行恢复等)感到困惑,以及如何这样做。

虽然我已经阅读了官方文档,但我坚持认为它们似乎有很多种方式,但有些方面有点复杂且“吵闹”#34;对于像这样的简单模型。

修改

我编辑并重新运行mnist文件,其代码与上述相同,但这些行除外:

...

x = tf.placeholder(tf.float32, shape=[None, 784], name='input')

...

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1), name='result')

...

然后,我尝试运行另一个.py代码(与上面代码在同一目录中),以便传递位于本地的手写数字图像(" mnist-input-image.png")在同一目录中:

import tensorflow as tf
from PIL import Image
import numpy as np

image_test = Image.open("mnist-input-image.png")
image = np.array(image_test)

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('/Users/username/.meta')
    new = saver.restore(sess, tf.train.latest_checkpoint('/Users/username/'))

    graph = tf.get_default_graph()
    input_x = graph.get_tensor_by_name("input:0")
    result = graph.get_tensor_by_name("result:0")

    feed_dict = {input_x: image}

    predictions = result.eval(feed_dict=feed_dict)
    print(predictions)

现在,如果我正确理解,我将图像作为numpy数组传递。然后,我的问题是:

1)这些行的确切文件引用是什么(因为我的用户文件夹中没有.meta文件夹)?

saver = tf.train.import_meta_graph('/Users/username/.meta')
new = saver.restore(sess, tf.train.latest_checkpoint('/Users/username/'))

我的意思是,哪些确切的文件引用这些行(来自我上面生成的文件列表)?

2)转换到我的情况,这行是否正确将我的numpy数组传递给feed dict?

feed_dict = {input_x: image}

2 个答案:

答案 0 :(得分:0)

一个简单的解决方案是使用会话对象。生成checkpoint文件后,可以使用Saver对象恢复它。

顺便问一下,你知道为什么大多数教程都在函数内部创建了图形吗?一个很好的理由是因为您可以使用输入快速反序列化图形。

启动会话的正确方法如下:

# Use your placeholders, variables, etc to create the entire graph.
# Usually you return the input placeholder, 
# prediction and the loss/accuracy here.
# You don't need the accuracy.
x, y, _ = make_your_graph(test_X, test_y)

# This object is the interface for serialization in tf
saver = tf.train.Saver()

with tf.Session() as sess:
    # Takes your current model's checkpoint. "./checkpoint" is your checkpoint file.
    saver.restore(sess, tf.train.latest_checkpoint("./checkpoint"))
    prediction = sess.run(y)

想要为已经启动的会话运行多个数据点吗?

然后用feed dict替换最后一行:

while waiting_for_new_y():
    another_y = get_new_y()
    feed_dict = {x: [another_y]}
    another_prediction = sess.run(y, feed_dict)

答案 1 :(得分:0)

首先,在稍后要使用的每个对象中为name参数赋值,以便稍后可以使用它的名称:

改变这个:

x = tf.placeholder(tf.float32, shape=[None, 784])

x = tf.placeholder(tf.float32, shape=[None, 784],name='input')

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1),name='result')

现在运行这个小脚本来存储模型:

import tensorflow as tf
with tf.Session() as sess:
    saver = tf.train.import_meta_graph('/Users/dummy/.meta')
    new=saver.restore(sess, tf.train.latest_checkpoint('/Users/dummy/'))

    graph = tf.get_default_graph()
    input_x = graph.get_tensor_by_name("input:0")
    result = graph.get_tensor_by_name("result:0")

    feed_dict = {input_x: mnist.test.images,}  #here you feed your new data for example i am feeding mnist

    predictions = result.eval(feed_dict=feed_dict)
    print(predictions)

你会得到输出。