我根据recent post by Google’s Derek Chow on the Google Cloud Big Data And Machine Learning Blog使用Cloud Machine Learning Engine训练了一个物体探测器,现在想要使用Cloud Machine Learning Engine进行预测。
说明包括将Tensorflow图导出为output_inference_graph.pb的代码,但不包括如何将protobuf格式(pb)转换为gcloud ml-engine预测所需的SavedModel格式。
我查看了answer by Google’s @rhaertel80有关如何转换“Tensorflow For Poets”图像分类模型和answer provided by Google’s @MarkMcDonald有关如何转换“Tensorflow For Poets 2”图像分类模型但似乎都没有工作对于博客文章中描述的对象检测器图(pb)。
如何转换该对象检测器图(pb)以便可以使用它或gcloud ml-engine预测,请?
答案 0 :(得分:2)
SavedModel在其MetaGraphDef内包含structure。 要从python中的GraphDef创建SavedModel,您可能希望使用链接中描述的构建器。
export_dir = ...
...
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
with tf.Session(graph=tf.Graph()) as sess:
...
builder.add_meta_graph_and_variables(sess,
[tag_constants.TRAINING],
signature_def_map=foo_signatures,
assets_collection=foo_assets)
...
with tf.Session(graph=tf.Graph()) as sess:
...
builder.add_meta_graph(["bar-tag", "baz-tag"])
...
builder.save()
答案 1 :(得分:0)
这篇文章救了我!希望能帮助到这里来的人们。我使用成功导出的方法https://stackoverflow.com/a/48102615/6124383
https://github.com/tensorflow/tensorflow/pull/15855/commits/81ec5d20935352d71ff56fac06c36d6ff0a7ae05
def export_model(sess, architecture, saved_model_dir):
if architecture == 'inception_v3':
input_tensor = 'DecodeJpeg/contents:0'
elif architecture.startswith('mobilenet_'):
input_tensor = 'input:0'
else:
raise ValueError('Unknown architecture', architecture)
in_image = sess.graph.get_tensor_by_name(input_tensor)
inputs = {'image': tf.saved_model.utils.build_tensor_info(in_image)}
out_classes = sess.graph.get_tensor_by_name('final_result:0')
outputs = {'prediction': tf.saved_model.utils.build_tensor_info(out_classes)}
signature = tf.saved_model.signature_def_utils.build_signature_def(
inputs=inputs,
outputs=outputs,
method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
)
legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
# Save out the SavedModel.
builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)
builder.add_meta_graph_and_variables(
sess, [tf.saved_model.tag_constants.SERVING],
signature_def_map={
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature
},
legacy_init_op=legacy_init_op)
builder.save()
#execute this in the final of def main(_):
export_model(sess, FLAGS.architecture, FLAGS.saved_model_dir)
parser.add_argument(
'--saved_model_dir',
type=str,
default='/tmp/saved_models/1/',
help='Where to save the exported graph.'
)