In TensorFlow Serving

Time: 2016-11-28 16:44:34

Tags: tensorflow tensorflow-serving

I am running the example iris program with TensorFlow Serving. Since it is a TF.Learn model, I export it with classifier.export(export_dir=model_dir, signature_fn=my_classification_signature_fn), where signature_fn is defined as follows:

from tensorflow.contrib.session_bundle import exporter

def my_classification_signature_fn(examples, unused_features, predictions):
  """Creates classification signature from given examples and predictions.

  Args:
    examples: `Tensor`.
    unused_features: `dict` of `Tensor`s.
    predictions: `Tensor` or dict of tensors that contains the classes tensor
      as in {'classes': `Tensor`}.

  Returns:
    Tuple of default classification signature and empty named signatures.

  Raises:
    ValueError: If examples is `None`.
  """
  if examples is None:
    raise ValueError('examples cannot be None when using this signature fn.')

  if isinstance(predictions, dict):
    default_signature = exporter.classification_signature(
        examples, classes_tensor=predictions['classes'])
  else:
    default_signature = exporter.classification_signature(
        examples, classes_tensor=predictions)
  named_graph_signatures = {
      'inputs': exporter.generic_signature({'x_values': examples}),
      'outputs': exporter.generic_signature({'preds': predictions})}
  return default_signature, named_graph_signatures

The model is exported successfully with the code above.

I then created a client that uses TensorFlow Serving for real-time prediction.

Here is the client code:

import numpy as np
import tensorflow as tf
from grpc.beta import implementations
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2

tf.app.flags.DEFINE_string('model_dir', '/tmp/iris_model_dir',
                           'Base directory for output models.')
tf.app.flags.DEFINE_integer('concurrency', 1,
                            'maximum number of concurrent inference requests')
tf.app.flags.DEFINE_string('server', '', 'PredictionService host:port')
FLAGS = tf.app.flags.FLAGS

# Open an insecure gRPC channel to the serving host.
host, port = FLAGS.server.split(':')
channel = implementations.insecure_channel(host, int(port))
stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)

# Classify a new flower sample.
new_samples = np.array([5.8, 3.1, 5.0, 1.7], dtype=float)

request = predict_pb2.PredictRequest()
request.model_spec.name = 'iris'
request.inputs['x_values'].CopyFrom(
    tf.contrib.util.make_tensor_proto(new_samples))

result = stub.Predict(request, 10.0)  # 10 secs timeout
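As an aside, new_samples above is a flat vector of four values, while the model is trained on 4-dimensional feature vectors, so a single sample is commonly sent as a batch of shape (1, 4). A minimal NumPy sketch of that shaping (the expected shape is an assumption based on the feature_columns definition shown later):

```python
import numpy as np

# One iris sample with 4 features, shaped as a batch of size 1 rather
# than a flat (4,) vector.
new_samples = np.array([[5.8, 3.1, 5.0, 1.7]], dtype=np.float32)
print(new_samples.shape)  # (1, 4)
```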

However, when making a prediction, the following error is raised:

grpc.framework.interfaces.face.face.AbortionError: AbortionError(code=StatusCode.INTERNAL, details="Output 0 of type double does not match declared output type string for node _recv_input_example_tensor_0 = _Recv[client_terminated=true, recv_device="/job:localhost/replica:0/task:0/cpu:0", send_device="/job:localhost/replica:0/task:0/cpu:0", send_device_incarnation=2016246895612781641, tensor_name="input_example_tensor:0", tensor_type=DT_STRING, _device="/job:localhost/replica:0/task:0/cpu:0"]()")

The full stack trace was attached as a screenshot.

The iris model is defined as follows:

# Specify that all features have real-value data
feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

# Build 3 layer DNN with 10, 20, 10 units respectively.
classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
                                            hidden_units=[10, 20, 10],
                                            n_classes=3, model_dir=model_dir)

# Fit model.
classifier.fit(x=training_set.data, 
               y=training_set.target, 
               steps=2000)

Please advise on how to resolve this error.

1 Answer:

Answer 0 (score: 0):

I think the problem is that your signature_fn is taking the else branch and passing predictions as the output of the classification signature, which expects string outputs rather than doubles. Either use a regression signature function, or add an op to the graph so that the output is in string form.
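To illustrate what "output in string form" means, independent of TensorFlow: the classification signature wants string class labels, while the model emits integer class ids. The label names below are hypothetical (the exported iris model only produces ids 0 through 2); the point is only the id-to-string conversion:

```python
# Hypothetical human-readable labels; the model itself only emits ids 0..2.
CLASS_NAMES = ['setosa', 'versicolor', 'virginica']

def classes_to_strings(class_ids):
    """Map integer class ids to string labels so the classification
    signature receives strings instead of numeric outputs."""
    return [CLASS_NAMES[i] for i in class_ids]

print(classes_to_strings([0, 2]))  # ['setosa', 'virginica']
```

Inside the exported graph itself, the same conversion would have to happen as a TensorFlow op on the classes tensor before it is handed to exporter.classification_signature.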