如何根据训练有素的Tensorflow模型进行预测?

时间:2017-09-16 07:51:02

标签: python python-3.x tensorflow neural-network

我在Stackoverflow上检查了很多关于我的问题的问题,但我仍然遇到问题。

我已按照Deep MNIST for Experts中的教程使用mnist_deep.py中的代码,并使用resx将模型保存到磁盘。

在我的predict.py中,我使用tf.saved_model.builder.SavedModelBuilder()加载模型,在加载模型后,基于我在Google上搜索的大量搜索,我知道我必须运行{{1}要进行预测,我也知道变量tf.saved_model.loader.load(),它应该是最后一层,对于sess.run(y_, feed_dict={x: test_data})中的'x',它应该是训练中输入的占位符。 / p>

我的问题是,我不知道哪个代码属于mnist_deep.py中的最后一层。

我的mnist_deep.py代码如下:

y

这是我的predict.py:

feed_dict

1 个答案:

答案 0 :(得分:0)

您没有运行y_,它是占位符。

加载模型后,您也不会重新定义变量,而是使用保存的变量。

因此,加载后,只需运行sess.run(y_conv, feed_dict={x: test_data})

y_conv是模型最后一层的预测输出。

要在加载模型后访问y_conv,请通过以下方式获取:{ y_conv = sess.graph.get_tensor_by_name("it's name goes here")
您需要在保存之前命名y_conv

或者,您可以在保存前将y_conv添加到集合中,然后在加载模型后从集合中检索它:
tf.add_to_collection('vars', y_conv)
y_conv = tf.get_collection('vars')[0]