我已经成功保存了一个简单的MNIST,成为下一个代码:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
import tensorflow as tf
sess = tf.InteractiveSession()
tf_save_file = './mnist-to-save-saved'
x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
y = tf.matmul(x, W) + b
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels = y_, logits = y))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
saver.save(sess, tf_save_file)
for _ in range(1000):
batch = mnist.train.next_batch(100)
train_step.run(feed_dict={x: batch[0], y_: batch[1]})
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
saver.save(sess, tf_save_file, global_step=1000)
print(accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
然后,下一个文件 生成:
checkpoint
mnist-to-save-saved-1000.data-00000-of-00001
mnist-to-save-saved-1000.index
mnist-to-save-saved-1000.meta
mnist-to-save-saved.data-00000-of-00001
mnist-to-save-saved.index
mnist-to-save-saved.meta
现在,为了在生产中使用它(例如,传递一个数字图像),我希望能够通过传递> 执行训练模型 任何数字图像进行预测(我的意思是,不是部署服务器而是进行预测" 本地",在同一目录"修复"编号图像,因此使用模型就像运行可执行文件一样。)
但是,考虑到我的代码的(中 - 低?)API级别,我对最简单正确的下一步(如果使用Estimator进行恢复等)感到困惑,以及如何这样做。
虽然我已经阅读了官方文档,但我坚持认为它们似乎有很多种方式,但有些方面有点复杂且“吵闹”#34;对于像这样的简单模型。
修改
我编辑并重新运行mnist文件,其代码与上述相同,但这些行除外:
...
x = tf.placeholder(tf.float32, shape=[None, 784], name='input')
...
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1), name='result')
...
然后,我尝试运行另一个.py代码(与上面代码在同一目录中),以便传递位于本地的手写数字图像(" mnist-input-image.png")在同一目录中:
import tensorflow as tf
from PIL import Image
import numpy as np
image_test = Image.open("mnist-input-image.png")
image = np.array(image_test)
with tf.Session() as sess:
saver = tf.train.import_meta_graph('/Users/username/.meta')
new = saver.restore(sess, tf.train.latest_checkpoint('/Users/username/'))
graph = tf.get_default_graph()
input_x = graph.get_tensor_by_name("input:0")
result = graph.get_tensor_by_name("result:0")
feed_dict = {input_x: image}
predictions = result.eval(feed_dict=feed_dict)
print(predictions)
现在,如果我正确理解,我将图像作为numpy数组传递。然后,我的问题是:
1)这些行的确切文件引用是什么(因为我的用户文件夹中没有.meta文件夹)?
saver = tf.train.import_meta_graph('/Users/username/.meta')
new = saver.restore(sess, tf.train.latest_checkpoint('/Users/username/'))
我的意思是,哪些确切的文件引用这些行(来自我上面生成的文件列表)?
2)转换到我的情况,这行是否正确将我的numpy数组传递给feed dict?
feed_dict = {input_x: image}
答案 0 :(得分:0)
一个简单的解决方案是使用会话对象。生成checkpoint
文件后,可以使用Saver
对象恢复它。
顺便问一下,你知道为什么大多数教程都在函数内部创建了图形吗?一个很好的理由是因为您可以使用输入快速反序列化图形。
启动会话的正确方法如下:
# Use your placeholders, variables, etc to create the entire graph.
# Usually you return the input placeholder,
# prediction and the loss/accuracy here.
# You don't need the accuracy.
x, y, _ = make_your_graph(test_X, test_y)
# This object is the interface for serialization in tf
saver = tf.train.Saver()
with tf.Session() as sess:
# Takes your current model's checkpoint. "./checkpoint" is your checkpoint file.
saver.restore(sess, tf.train.latest_checkpoint("./checkpoint"))
prediction = sess.run(y)
想要为已经启动的会话运行多个数据点吗?
然后用feed dict替换最后一行:
while waiting_for_new_y():
another_y = get_new_y()
feed_dict = {x: [another_y]}
another_prediction = sess.run(y, feed_dict)
答案 1 :(得分:0)
首先,在稍后要使用的每个对象中为name参数赋值,以便稍后可以使用它的名称:
改变这个:
x = tf.placeholder(tf.float32, shape=[None, 784])
到
x = tf.placeholder(tf.float32, shape=[None, 784],name='input')
和
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
到
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1),name='result')
现在运行这个小脚本来存储模型:
import tensorflow as tf
with tf.Session() as sess:
saver = tf.train.import_meta_graph('/Users/dummy/.meta')
new=saver.restore(sess, tf.train.latest_checkpoint('/Users/dummy/'))
graph = tf.get_default_graph()
input_x = graph.get_tensor_by_name("input:0")
result = graph.get_tensor_by_name("result:0")
feed_dict = {input_x: mnist.test.images,} #here you feed your new data for example i am feeding mnist
predictions = result.eval(feed_dict=feed_dict)
print(predictions)
你会得到输出。