我在Stackoverflow上检查了很多关于我的问题的问题,但我仍然遇到问题。
我已按照Deep MNIST for Experts中的教程使用mnist_deep.py中的代码,并使用resx
将模型保存到磁盘。
在我的predict.py中,我使用tf.saved_model.builder.SavedModelBuilder()
加载模型,在加载模型后,基于我在Google上搜索的大量搜索,我知道我必须运行{{1}要进行预测,我也知道变量tf.saved_model.loader.load()
,它应该是最后一层,对于sess.run(y_, feed_dict={x: test_data})
中的'x',它应该是训练中输入的占位符。 / p>
我的问题是,我不知道哪个代码属于mnist_deep.py中的最后一层。
我的mnist_deep.py代码如下:
y
这是我的predict.py:
feed_dict
答案 0 :(得分:0)
您没有运行y_
,它是占位符。
加载模型后,您也不会重新定义变量,而是使用保存的变量。
因此,加载后,只需运行sess.run(y_conv, feed_dict={x: test_data})
y_conv
是模型最后一层的预测输出。
要在加载模型后访问y_conv
,请通过以下方式获取:{
y_conv = sess.graph.get_tensor_by_name("it's name goes here")
您需要在保存之前命名y_conv
。
或者,您可以在保存前将y_conv
添加到集合中,然后在加载模型后从集合中检索它:
tf.add_to_collection('vars', y_conv)
y_conv = tf.get_collection('vars')[0]