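Single-image inference in TensorFlow [Python]

Time: 2017-08-15 17:04:19

Tags: python tensorflow

I have converted a pre-trained .ckpt file to a .pb file, freezing the model and saving the weights. What I want to do now is run a simple inference using the .pb file, then extract and save the output image. The model (a fully convolutional network for semantic segmentation) was downloaded from here: https://github.com/MarvinTeichmann/KittiSeg. So far I have managed to load the image, set the default TF graph and import the graph defined by the model, read the input and output tensors, and run the session (the error occurs here).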

import tensorflow as tf
import os
import numpy as np
from tensorflow.python.platform import gfile
from PIL import Image

# Read the image and get its statistics
img=Image.open('/path-to-image/demoImage.png')
img.show()
width, height = img.size
print(width)
print(height)

#Plot the image
#image.show()

with tf.Graph().as_default() as graph:

        with tf.Session() as sess:

                # Load the graph in graph_def
                print("load graph")

                # Load the protobuf file from disk and parse it to retrieve the unserialized graph_def
                with gfile.FastGFile("/path-to-FCN-model/FCN8.pb",'rb') as f:

                                #Set default graph as current graph
                                graph_def = tf.GraphDef()
                                graph_def.ParseFromString(f.read())
                                #sess.graph.as_default() #new line

                                # Import a graph_def into the current default Graph
                                tf.import_graph_def(graph_def, name='')

                                # Print the name of operations in the session
                                #for op in sess.graph.get_operations():

                                    #print "Operation Name :",op.name            # Operation name
                                    #print "Tensor Stats :",str(op.values())     # Tensor name

                                # INFERENCE Here
                                l_input = graph.get_tensor_by_name('Placeholder:0')
                                l_output = graph.get_tensor_by_name('save/Assign_38:0')

                                print("l_input: {}".format(l_input))
                                print("l_output: {}".format(l_output))
                                print("")
                                print("")

                                # Acceptable feed values include Python scalars, strings, lists, numpy ndarrays, or TensorHandles.                              
                                result = sess.run(l_output, feed_dict={l_input : img})
                                print(result)

                                print("Inference done")

                                # Info
                                # First Tensor name : Placeholder:0
                                # Last tensor name  : save/Assign_38:0

The error might come from the image format (for example, should I convert the .png to another format?). Or is this a more fundamental error?

2 answers:

Answer 0: (score: 2)

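I managed to fix the error. Below is a working script for running inference on a single image with a fully convolutional network (for anyone interested in an alternative segmentation algorithm to SegNet). This model uses bilinear interpolation for upscaling rather than unpooling layers. Since the model is available for download in .ckpt format, you first have to freeze the model and save it as a .pb file. After that, you have to pass the network through the TF optimizer to set the dropout probabilities to 1. Then set the correct input and output tensor names in this script, and inference works fine, extracting the segmented image.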
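A minimal sketch of such an inference script, under stated assumptions: it uses the TF 1.x-style API through `tensorflow.compat.v1`, it assumes the frozen graph's input placeholder accepts a single unbatched H x W x 3 float image, and the helper names (`to_float_array`, `load_frozen_graph`, `run_inference`, `segment_image`) are illustrative rather than taken from KittiSeg. The tensor names must be looked up in your own graph. Note that `sess.run` is fed a NumPy array here, not a PIL `Image` object.

```python
import numpy as np


def to_float_array(image):
    """Convert a PIL image or array-like to a float32 array of shape (H, W, 3)."""
    array = np.asarray(image, dtype=np.float32)
    if array.ndim == 2:
        # Grayscale input: repeat the single channel three times.
        array = np.stack([array] * 3, axis=-1)
    # Drop an alpha channel if one is present.
    return array[..., :3]


def load_frozen_graph(pb_path):
    """Parse a frozen GraphDef from disk and import it into a new graph."""
    # Assumption: the installed TensorFlow ships the compat.v1 module.
    import tensorflow.compat.v1 as tf

    graph_def = tf.GraphDef()
    with tf.gfile.GFile(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name="")
    return graph


def run_inference(graph, image, input_name, output_name):
    """Feed one image to the graph and return the fetched output array."""
    import tensorflow.compat.v1 as tf

    input_tensor = graph.get_tensor_by_name(input_name)
    output_tensor = graph.get_tensor_by_name(output_name)
    with tf.Session(graph=graph) as sess:
        return sess.run(output_tensor,
                        feed_dict={input_tensor: to_float_array(image)})


def segment_image(pb_path, image_path, input_name, output_name):
    """Load a frozen model and an image file, then run a single inference."""
    from PIL import Image

    graph = load_frozen_graph(pb_path)
    # Converting to RGB avoids palette or alpha modes reaching the network.
    image = Image.open(image_path).convert("RGB")
    return run_inference(graph, image, input_name, output_name)
```

To find the tensor names, iterate over `graph.get_operations()` and print each `op.name`. The output should be the network's prediction tensor; an operation under the `save/` scope, such as `save/Assign_38`, belongs to the checkpoint saver rather than to the model's forward pass.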


Answer 1: (score: 0)

Have you looked at demo.py? Line 141 shows how they modify the graph's input:

# Create placeholder for input
image_pl = tf.placeholder(tf.float32)
image = tf.expand_dims(image_pl, 0)

# build Tensorflow graph using the model from logdir
prediction = core.build_inference_graph(hypes, modules,
                                        image=image)

And line 164 shows how the image is opened:

image = scp.misc.imread(input_image)

The result is fed directly to image_pl. The only catch is that core.build_inference_graph is a TensorVision call.
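The placeholder in that snippet has no batch dimension; `tf.expand_dims(image_pl, 0)` adds one inside the graph, so the array fed to `image_pl` is the plain H x W x C image. A small NumPy sketch of the same shape change follows. The helper name `add_batch_dimension` and the dummy image are illustrative and not part of KittiSeg.

```python
import numpy as np


def add_batch_dimension(image):
    """NumPy counterpart of tf.expand_dims(image, 0): (H, W, C) -> (1, H, W, C)."""
    return np.expand_dims(np.asarray(image, dtype=np.float32), axis=0)


# A dummy 4 x 6 RGB image stands in for the array returned by an image reader.
dummy_image = np.zeros((4, 6, 3), dtype=np.uint8)
batched = add_batch_dimension(dummy_image)
print(dummy_image.shape, "->", batched.shape)  # (4, 6, 3) -> (1, 4, 6, 3)
```

If the graph already contains the `expand_dims` op, feed the unbatched array. Also note that `scipy.misc.imread` has been removed from recent SciPy releases, so on a current installation the image would likely need to be read another way, for example with PIL followed by `np.asarray`.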

Note that it would also help to provide the exact error message.