如何为“广泛而深入”的客户创建一个数据流服务客户端。模型?

时间:2017-01-17 12:20:29

标签: java tensorflow deep-learning tensorflow-serving

我已经创建了一个基于“广泛而深入”的模型。示例(https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/learn/wide_n_deep_tutorial.py)。

我已按如下方式导出模型:

  m = build_estimator(model_dir)
  m.fit(input_fn=lambda: input_fn(df_train, True), steps=FLAGS.train_steps)
  results = m.evaluate(input_fn=lambda: input_fn(df_test, True), steps=1)

  print('Model statistics:')

  for key in sorted(results):
    print("%s: %s" % (key, results[key]))

  print('Done training!!!')

  # Export model
  export_path = sys.argv[-1]
  print('Exporting trained model to %s' % export_path)

  m.export(
   export_path,
   input_fn=serving_input_fn,
   use_deprecated_input_fn=False,
   input_feature_key=INPUT_FEATURE_KEY

我的问题是,如何创建客户端以从此导出的模型进行预测?我是否正确导出了模型?

最终我需要能够在Java中做到这一点。我怀疑我可以通过使用gRPC从proto文件创建Java类来实现这一点。

文档非常粗略,因此我在这里问。

非常感谢!

2 个答案:

答案 0 :(得分:2)

我写了一个简单的教程Exporting and Serving a TensorFlow Wide & Deep Model

TL; DR

要导出估算器,有四个步骤:

  1. 将导出功能定义为估算器初始化期间使用的所有功能的列表。

  2. 使用create_feature_spec_for_parsing创建功能配置。

  3. 使用serving_input_fn构建适合使用的input_fn_utils.build_parsing_serving_input_fn

  4. 使用export_savedmodel()导出模型。

  5. 要正确运行客户端脚本,您需要执行以下三个步骤:

    1. 创建并将脚本放在/ serve /文件夹中的某个位置,例如/服务/ tensorflow_serving /示例/

    2. 通过添加py_binary

    3. 来创建或修改相应的BUILD文件
    4. 构建并运行模型服务器,例如tensorflow_model_server

    5. 创建,构建并运行一个客户端,该客户端将tf.Example发送到我们的tensorflow_model_server进行推断。

    6. 有关详细信息,请查看教程本身。

答案 1 :(得分:1)

花了一个坚实的一周来搞清楚这一点。首先,m.export将在几周后弃用,因此请使用:m.export_savedmodel(export_path, input_fn=serving_input_fn)而不是该块。

这意味着您必须定义serving_input_fn(),当然这应该与广泛深入的教程中定义的input_fn()具有不同的签名。也就是说,向前推进,我想建议input_fn() - 类型的东西应该返回一个InputFnOps对象,定义为here

以下是我如何确定如何完成这项工作的方法:

from tensorflow.contrib.learn.python.learn.utils import input_fn_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.framework import dtypes

def serving_input_fn():
  features, labels = input_fn()
  features["examples"] = tf.placeholder(tf.string)

  serialized_tf_example = array_ops.placeholder(dtype=dtypes.string,
                                                shape=[None],
                                                name='input_example_tensor')
  inputs = {'examples': serialized_tf_example}
  labels = None  # these are not known in serving!
  return input_fn_utils.InputFnOps(features, labels, inputs)

这可能不是100%惯用,但我很确定它有效。现在。