如何使用Tensorflow服务提供再训练的初始模型?

时间:2016-05-15 11:51:57

标签: tensorflow tensorflow-serving

所以我根据本指南训练了初始模型以识别花朵。 https://www.tensorflow.org/versions/r0.8/how_tos/image_retraining/index.html

android {
   buildTypes {
      debug {
         testCoverageEnabled = true
      }
   }
}

要通过命令行对图像进行分类,我可以这样做:

bazel build tensorflow/examples/image_retraining:retrain
bazel-bin/tensorflow/examples/image_retraining/retrain --image_dir ~/flower_photos

但是如何通过Tensorflow服务来提供此图表?

关于设置Tensorflow服务(https://tensorflow.github.io/serving/serving_basic)的指南并未说明如何合并图形(output_graph.pb)。服务器需要不同格式的文件:

bazel build tensorflow/examples/label_image:label_image && \
bazel-bin/tensorflow/examples/label_image/label_image \
--graph=/tmp/output_graph.pb --labels=/tmp/output_labels.txt \
--output_layer=final_result \
--image=$HOME/flower_photos/daisy/21652746_cc379e0eea_m.jpg

3 个答案:

答案 0 :(得分:2)

要在训练后提供图表,您需要使用此API导出图片:https://www.tensorflow.org/versions/r0.8/api_docs/python/train.html#export_meta_graph

该api生成服务代码所需的元图def(这将生成您要询问的.meta文件)

此外,您需要使用Saver.save()恢复检查点,Saver.save()是Saver类https://www.tensorflow.org/versions/r0.8/api_docs/python/train.html#Saver

完成此操作后,您将同时使用metagraph def和恢复图形所需的检查点文件。

答案 1 :(得分:1)

您必须导出模型。我有一个PR在重新训练期间输出模型。它的要点如下:

import tensorflow as tf

def export_model(sess, architecture, saved_model_dir):
  if architecture == 'inception_v3':
    input_tensor = 'DecodeJpeg/contents:0'
  elif architecture.startswith('mobilenet_'):
    input_tensor = 'input:0'
  else:
    raise ValueError('Unknown architecture', architecture)
  in_image = sess.graph.get_tensor_by_name(input_tensor)
  inputs = {'image': tf.saved_model.utils.build_tensor_info(in_image)}

  out_classes = sess.graph.get_tensor_by_name('final_result:0')
  outputs = {'prediction': tf.saved_model.utils.build_tensor_info(out_classes)}

  signature = tf.saved_model.signature_def_utils.build_signature_def(
    inputs=inputs,
    outputs=outputs,
    method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
  )

  legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')

  # Save out the SavedModel.
  builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir)
  builder.add_meta_graph_and_variables(
    sess, [tf.saved_model.tag_constants.SERVING],
    signature_def_map={
      tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature
    },
    legacy_init_op=legacy_init_op)
  builder.save()

上面将创建一个变量目录和saved_model.pb文件。如果你把它放在代表版本号的父目录下(例如1 /),那么你可以通过以下方式调用tensorflow:

tensorflow_model_server --port=9000 --model_name=inception --model_base_path=/path/to/saved_models/

答案 2 :(得分:0)

查看这个要点如何在会话中加载.pb输出图:

https://github.com/eldor4do/Tensorflow-Examples/blob/master/retraining-example.py