如何对保存的估算器模型进行简单的CLI查询?

时间:2018-07-06 14:02:25

标签: python tensorflow command-line

我已经成功地训练了DNNClassifier来对文本进行分类(来自在线讨论区的帖子)。我已经保存了模型,现在想使用TensorFlow CLI对文本进行分类。

为保存的模型运行saved_model_cli show时,得到以下输出:

saved_model_cli show --dir /my/model --tag_set serve --signature_def predict
The given SavedModel SignatureDef contains the following input(s):
  inputs['examples'] tensor_info:
      dtype: DT_STRING
      shape: (-1)
      name: input_example_tensor:0
The given SavedModel SignatureDef contains the following output(s):
  outputs['class_ids'] tensor_info:
      dtype: DT_INT64
      shape: (-1, 1)
      name: dnn/head/predictions/ExpandDims:0
  outputs['classes'] tensor_info:
      dtype: DT_STRING
      shape: (-1, 1)
      name: dnn/head/predictions/str_classes:0
  outputs['logistic'] tensor_info:
      dtype: DT_FLOAT
      shape: (-1, 1)
      name: dnn/head/predictions/logistic:0
  outputs['logits'] tensor_info:
      dtype: DT_FLOAT
      shape: (-1, 1)
      name: dnn/logits/BiasAdd:0
  outputs['probabilities'] tensor_info:
      dtype: DT_FLOAT
      shape: (-1, 2)
      name: dnn/head/predictions/probabilities:0
Method name is: tensorflow/serving/predict

我无法找出saved_model_cli run的正确参数来获得预测。

我尝试了几种方法,例如:

saved_model_cli run --dir /my/model --tag_set serve --signature_def predict --input_exprs='examples=["klassifiziere mich bitte"]'

哪个给我这个错误消息:

InvalidArgumentError (see above for traceback): Could not parse example input, value: 'klassifiziere mich bitte'
 [[Node: ParseExample/ParseExample = ParseExample[Ndense=1, Nsparse=0, Tdense=[DT_STRING], dense_shapes=[[1]], sparse_types=[], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_input_example_tensor_0_0, ParseExample/ParseExample/names, ParseExample/ParseExample/dense_keys_0, ParseExample/ParseExample/names)]]

将输入字符串传递到CLI以获得分类的正确方法是什么?

您可以在GitHub上找到我的项目代码,包括培训数据:https://github.com/pahund/beitragstuev

我正在建立和保存这样的模型(简化为see GitHub for original code):

embedded_text_feature_column = hub.text_embedding_column(
    key="sentence",
    module_spec="https://tfhub.dev/google/nnlm-de-dim128/1")
feature_columns = [embedded_text_feature_column]
estimator = tf.estimator.DNNClassifier(
    hidden_units=[500, 100],
    feature_columns=feature_columns,
    n_classes=2,
    optimizer=tf.train.AdagradOptimizer(learning_rate=0.003))
feature_spec = tf.feature_column.make_parse_example_spec(feature_columns)
serving_input_receiver_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)
estimator.export_savedmodel(export_dir_base="/my/dir/base", serving_input_receiver_fn=serving_input_receiver_fn)

2 个答案:

答案 0 :(得分:5)

您为模型导出创建的dfs<-list() for (i in 1:1000) { dfs[[i]]<-iris[sample(1:length(iris$Sepal.Length),80),-5] } 告诉保存的模型期望序列化的ServingInputReceiver原型,而不是您希望分类的原始字符串。

来自the Save and Restore documentation

  

一个典型的模式是推理请求以序列化的tf.Examples的形式到达,因此serving_input_receiver_fn()创建一个单个字符串占位符来接收它们。然后,serving_input_receiver_fn()还负责解析tf。示例通过向图表中添加tf.parse_example op。

     

....

     

tf.estimator.export.build_parsing_serving_input_receiver_fn实用程序功能为常见情况提供了该输入接收器。

因此,您导出的模型包含一个tf.parse_example操作,该操作期望接收满足您传递给build_parsing_serving_input_receiver_fn的功能规范的序列化tf.Example原型,即在您的情况下,它期望序列化的示例具有tf.Example功能。要使用该模型进行预测,您必须提供这些序列化的原型。

幸运的是,Tensorflow使构建它们变得相当容易。这是一个可能的函数,该表达式返回将sentence输入键映射到一批字符串的表达式,然后可以将其传递到CLI:

examples

因此,请使用从示例中提取的一些字符串:

import tensorflow as tf

def serialize_example_string(strings):

  serialized_examples = []
  for s in strings:
    try:
      value = [bytes(s, "utf-8")]
    except TypeError:  # python 2
      value = [bytes(s)]

    example = tf.train.Example(
                features=tf.train.Features(
                  feature={
                    "sentence": tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
                  }
                )
              )
    serialized_examples.append(example.SerializeToString())

  return "examples=" + repr(serialized_examples).replace("'", "\"")

CLI命令为:

strings = ["klassifiziere mich bitte",
           "Das Paket „S Line Competition“ umfasst unter anderem optische Details, eine neue Farbe (Turboblau), 19-Zöller und LED-Lampen.",
           "(pro Stimme geht 1 Euro Spende von Pfuscher ans Forum) ah du sack, also so gehts ja net :D:D:D"]

print (serialize_example_string(strings))

应该会为您提供所需的结果:

saved_model_cli run --dir /path/to/model --tag_set serve --signature_def predict --input_exprs='examples=[b"\n*\n(\n\x08sentence\x12\x1c\n\x1a\n\x18klassifiziere mich bitte", b"\n\x98\x01\n\x95\x01\n\x08sentence\x12\x88\x01\n\x85\x01\n\x82\x01Das Paket \xe2\x80\x9eS Line Competition\xe2\x80\x9c umfasst unter anderem optische Details, eine neue Farbe (Turboblau), 19-Z\xc3\xb6ller und LED-Lampen.", b"\np\nn\n\x08sentence\x12b\n`\n^(pro Stimme geht 1 Euro Spende von Pfuscher ans Forum) ah du sack, also so gehts ja net :D:D:D"]'

答案 1 :(得分:2)

或者,save_model_cli提供了另一个选项someObservable.pipe(map(mappingLogicMethod)).subscribe(x => console.log(x)); 而不是CREATE DATABASE databasename; ,因此您可以直接在cmd行中传递tf.Examples数据,而无需手动序列化。

例如:

--input_examples

有关详细信息,请参见https://www.tensorflow.org/guide/saved_model#--input_examples