I trained an inception_resnet_v2 model on flower images by following the README at https://github.com/tensorflow/models/tree/master/slim
After training I got a graph.pbtxt file, which I converted to a graph.pb file with the following code:
import tensorflow as tf
from google.protobuf import text_format

def convert_pbtxt_to_graphdef(filename):
    """Returns a `tf.GraphDef` proto representing the data in the given pbtxt file.

    Args:
      filename: The name of a file containing a GraphDef pbtxt (text-formatted
        `tf.GraphDef` protocol buffer data).
    Returns:
      A `tf.GraphDef` protocol buffer.
    """
    with tf.gfile.FastGFile(filename, 'r') as f:
        graph_def = tf.GraphDef()
        file_content = f.read()
        # Merges the human-readable string in `file_content` into `graph_def`.
        text_format.Merge(file_content, graph_def)
    return graph_def

with tf.gfile.FastGFile('/foo/bar/workspace/results/graph.pb', 'wb') as f:
    # Serialize the proto to bytes before writing it to the binary file.
    graph_def = convert_pbtxt_to_graphdef('/foo/bar/workspace/results/graph.pbtxt')
    f.write(graph_def.SerializeToString())
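With this file in hand, I tried to feed a random image to the trained model using TensorFlow's classify_image.py:
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/models/image/imagenet/classify_image.py
I used my .pb file, my .pbtxt file, and my label file, but I get the following error: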
Traceback (most recent call last):
  File "classify_image.py", line 212, in <module>
    tf.app.run()
  File "/usr/local/lib/python2.7/site-packages/tensorflow/python/platform/app.py", line 30, in run
    sys.exit(main(sys.argv[:1] + flags_passthrough))
  File "classify_image.py", line 208, in main
    run_inference_on_image(image)
  File "classify_image.py", line 170, in run_inference_on_image
    softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
  File "/usr/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2615, in get_tensor_by_name
    return self.as_graph_element(name, allow_tensor=True, allow_operation=False)
  File "/usr/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2466, in as_graph_element
    return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
  File "/usr/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2508, in _as_graph_element_locked
    "graph." % (repr(name), repr(op_name)))
KeyError: "The name 'softmax:0' refers to a Tensor which does not exist. The operation, 'softmax', does not exist in the graph."
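Answer 0 (score: 6)
The problem with slim, and really with tensorflow/models in general, is that the framework, and therefore the models it produces, do not fit the prediction use case well:
TF-slim is a new lightweight high-level API of TensorFlow (tensorflow.contrib.slim) for defining, training and evaluating complex models. (Source: https://github.com/tensorflow/models/tree/master/slim)
The main problem at prediction time is that slim only works well with checkpoint files created by the Saver class. With a checkpoint file, the assign_from_checkpoint_fn() method can be used to initialize all variables with the trained parameters stored in the checkpoint. If all you have is a GraphDef file (*.pb), on the other hand, you are somewhat stuck. There is a nice trick, though.
The key idea is to inject a tf.placeholder for the input image into the computation graph after the trained model has been saved as a checkpoint. The following script (convert_checkpoint_to_pb.py) reads a checkpoint, inserts the placeholder, converts the graph variables to constants, and dumps the result to a *.pb file.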
import tensorflow as tf
from tensorflow.contrib import slim
from nets import inception
from tensorflow.python.framework.graph_util import convert_variables_to_constants
from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference
from preprocessing import inception_preprocessing

checkpoints_dir = '/path/to/your/checkpoint_dir/'
OUTPUT_PB_FILENAME = 'minimal_graph.proto'
NUM_CLASSES = 2

# We need default size of image for a particular network.
# The network was trained on images of that size -- so we
# resize input image later in the code.
image_size = inception.inception_resnet_v2.default_image_size

with tf.Graph().as_default():
    # Inject placeholder into the graph
    input_image_t = tf.placeholder(tf.string, name='input_image')
    image = tf.image.decode_jpeg(input_image_t, channels=3)

    # Resize the input image, preserving the aspect ratio
    # and make a central crop of the resulting image.
    # The crop will be of the size of the default image size of
    # the network.
    # I use the "preprocess_for_eval()" method instead of "inception_preprocessing()"
    # because the latter crops all images to the center by 85% at
    # prediction time (training=False).
    processed_image = inception_preprocessing.preprocess_for_eval(image,
                                                                  image_size,
                                                                  image_size, central_fraction=None)

    # Networks accept images in batches.
    # The first dimension usually represents the batch size.
    # In our case the batch size is one.
    processed_images = tf.expand_dims(processed_image, 0)

    # Load the inception network structure
    with slim.arg_scope(inception.inception_resnet_v2_arg_scope()):
        logits, _ = inception.inception_resnet_v2(processed_images,
                                                  num_classes=NUM_CLASSES,
                                                  is_training=False)
    # Apply softmax function to the logits (output of the last layer of the network)
    probabilities = tf.nn.softmax(logits)

    model_path = tf.train.latest_checkpoint(checkpoints_dir)

    # Get the function that initializes the network structure (its variables) with
    # the trained values contained in the checkpoint
    init_fn = slim.assign_from_checkpoint_fn(
        model_path,
        slim.get_model_variables())

    with tf.Session() as sess:
        # Now call the initialization function within the session
        init_fn(sess)

        # Convert variables to constants and make sure the placeholder input_image is included
        # in the graph as well as the other necessary tensors.
        constant_graph = convert_variables_to_constants(sess, sess.graph_def, ["input_image", "DecodeJpeg",
                                                                               "InceptionResnetV2/Logits/Predictions"])

        # Define the input and output layer properly
        optimized_constant_graph = optimize_for_inference(constant_graph, ["eval_image"],
                                                          ["InceptionResnetV2/Logits/Predictions"],
                                                          tf.string.as_datatype_enum)

        # Write the production ready graph to file.
        tf.train.write_graph(optimized_constant_graph, '.', OUTPUT_PB_FILENAME, as_text=False)
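(The models/slim code must be on your Python path for this script to run.)
To predict new images with the converted model (now available as a *.pb file), use the code from the file minimal_predict.py: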
import tensorflow as tf
import urllib2

def create_graph(model_file):
    """Imports the GraphDef stored in model_file into the default graph."""
    # Creates graph from saved graph_def.pb.
    with tf.gfile.FastGFile(model_file, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(graph_def, name='')

model_file = "/your/path/to/minimal_graph.proto"
url = ("http://pictureparadise.net/funny-babies/funny-babies02/funny-babies-053.jpg")

# Open specified url and load image as a string
image_string = urllib2.urlopen(url).read()

with tf.Graph().as_default():
    with tf.Session() as new_sess:
        create_graph(model_file)
        softmax = new_sess.graph.get_tensor_by_name("InceptionResnetV2/Logits/Predictions:0")
        # Loading the injected placeholder
        input_placeholder = new_sess.graph.get_tensor_by_name("input_image:0")
        probabilities = new_sess.run(softmax, {input_placeholder: image_string})
        print(probabilities)
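The script above relies on urllib2, which exists only in Python 2. As a rough sketch, the image bytes could instead be loaded with a small helper that works on Python 2 and 3 and also accepts a local file path. The name load_image_bytes is hypothetical and not part of the original answer.

```python
import io

try:
    from urllib.request import urlopen  # Python 3
except ImportError:
    from urllib2 import urlopen  # Python 2


def load_image_bytes(source):
    """Return the raw bytes of an image given a URL or a local file path.

    Hypothetical helper: the returned bytes are what the string
    placeholder ("input_image:0") is fed with.
    """
    if source.startswith(("http://", "https://")):
        response = urlopen(source)
        try:
            return response.read()
        finally:
            response.close()
    # Anything else is treated as a path on the local file system.
    with io.open(source, "rb") as f:
        return f.read()
```

With this helper, image_string = load_image_bytes(url) would take the place of the urllib2 call, and a local JPEG path could be passed in the same way.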
To use these scripts, simply run
python convert_checkpoint_to_pb.py
python minimal_predict.py
with tensorflow and tensorflow/models/slim on your PYTHONPATH.
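The prediction script prints a raw array with one row per image and one column per class. A minimal sketch for turning that into readable results follows. It assumes a list of class names in the same order as the class indices used during training; top_predictions and the example labels are hypothetical.

```python
def top_predictions(probabilities, labels, k=5):
    """Return the k most likely (label, probability) pairs, best first.

    probabilities: nested sequence shaped (1, num_classes), as returned
      by running the softmax tensor on a single image.
    labels: class names ordered by class index (an assumption about how
      the label file is laid out).
    """
    scores = list(probabilities[0])  # the only image in the batch
    if len(scores) != len(labels):
        raise ValueError("expected %d labels, got %d" % (len(scores), len(labels)))
    ranked = sorted(zip(labels, scores), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Made-up values for a two-class model (NUM_CLASSES = 2).
example_probabilities = [[0.2, 0.8]]
example_labels = ["daisy", "roses"]
print(top_predictions(example_probabilities, example_labels, k=1))
# prints [('roses', 0.8)]
```

The length check is there because a mismatch between the label file and NUM_CLASSES would otherwise silently pair probabilities with the wrong names.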
Answer 1 (score: 0)
In the convert_checkpoint_to_pb.py provided by @Maximilian, ["eval_image"] in
optimized_constant_graph = optimize_for_inference(constant_graph, ["eval_image"],["InceptionResnetV2/Logits/Predictions"],tf.string.as_datatype_enum)
should be replaced with ["input_image"]. Otherwise you will get an error stating "The following input nodes were not found: {'eval_image'}\n".
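This error and the KeyError in the question have the same cause: a name was requested that is not in the graph. The sketch below is a plain-Python check that could be run before exporting or predicting. The helper names split_tensor_name and find_missing_nodes are hypothetical, and the node list is assumed to come from something like [n.name for n in graph_def.node].

```python
def split_tensor_name(tensor_name):
    """Split 'op_name:index' into (op_name, index); index defaults to 0."""
    op_name, sep, index = tensor_name.rpartition(":")
    if not sep:
        # No ':' present, so the whole string is the operation name.
        return tensor_name, 0
    return op_name, int(index)


def find_missing_nodes(node_names, requested):
    """Return the requested names whose operation is not in node_names."""
    available = set(node_names)
    return [name for name in requested
            if split_tensor_name(name)[0] not in available]


# Node names as the export script above would be expected to produce them
# (listed by hand here for illustration).
node_names = ["input_image", "DecodeJpeg", "InceptionResnetV2/Logits/Predictions"]
print(find_missing_nodes(node_names, ["eval_image", "softmax:0", "input_image:0"]))
# prints ['eval_image', 'softmax:0']
```

A tensor name is an operation name plus an output index, so 'softmax:0' can only be found if an operation called 'softmax' is in the graph. In the slim model the corresponding operation is InceptionResnetV2/Logits/Predictions.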