我已经将预先训练好的.ckpt文件转换为.pb文件,冻结模型并保存重量。我现在要做的是使用.pb文件进行简单的推断,并提取并保存输出图像。该模型是从这里下载的(完全卷积语义分割网络):https://github.com/MarvinTeichmann/KittiSeg。到目前为止,我已设法,加载图像,设置默认的tf图并导入模型定义的图形,读取输入和输出张量并运行会话(此处出错)。
import tensorflow as tf
import os
import numpy as np
from tensorflow.python.platform import gfile
from PIL import Image
# Read the image & get statstics
img=Image.open('/path-to-image/demoImage.png')
img.show()
width, height = img.size
print(width)
print(height)
#Plot the image
#image.show()
with tf.Graph().as_default() as graph:
with tf.Session() as sess:
# Load the graph in graph_def
print("load graph")
# We load the protobuf file from the disk and parse it to retrive the unserialized graph_drf
with gfile.FastGFile("/path-to-FCN-model/FCN8.pb",'rb') as f:
#Set default graph as current graph
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
#sess.graph.as_default() #new line
# Import a graph_def into the current default Graph
tf.import_graph_def(graph_def, name='')
# Print the name of operations in the session
#for op in sess.graph.get_operations():
#print "Operation Name :",op.name # Operation name
#print "Tensor Stats :",str(op.values()) # Tensor name
# INFERENCE Here
l_input = graph.get_tensor_by_name('Placeholder:0')
l_output = graph.get_tensor_by_name('save/Assign_38:0')
print "l_input", l_input
print "l_output", l_output
print
print
# Acceptable feed values include Python scalars, strings, lists, numpy ndarrays, or TensorHandles.
result = sess.run(l_output, feed_dict={l_input : img})
print(results)
print("Inference done")
# Info
# First Tensor name : Placeholder:0
# Last tensor name : save/Assign_38:0"
错误可能来自图像的格式(例如,我应该将.png转换为其他格式吗?)。这是另一个根本性错误吗?
答案 0 :(得分:2)
我设法修复了错误,下面是在完全卷积网络上推断单个图像的工作脚本(对于SEGNET的替代分割算法中有趣的人)。此模型使用billinear插值进行缩放而不是取消汇聚层。无论如何,因为该模型可以以.chkpt格式下载,所以必须先冻结模型并将其另存为.pb文件。稍后,您必须从TF优化器传递网络以将Dropout概率设置为1.然后,在此脚本中设置正确的输入和输出张量名称,并且推断正常工作,提取分段图像。
n
答案 1 :(得分:0)
您是否已查看demo.py。第141行显示了它们如何修改图表的输入:
# Create placeholder for input
image_pl = tf.placeholder(tf.float32)
image = tf.expand_dims(image_pl, 0)
# build Tensorflow graph using the model from logdir
prediction = core.build_inference_graph(hypes, modules,
image=image)
在第164行,图片如何打开:
image = scp.misc.imread(input_image)
直接送到image_pl。唯一的一点是core.build_inference_graph是一个TensorVision调用。
请注意,提供确切的错误消息也很有趣。