导出tensorflow contrib.learn导出以导出到tensorfow_serving

时间:2016-11-09 20:40:28

标签: python tensorflow tensorflow-serving

我正在尝试为tensorflow服务导出tf.contrib.learn.DNNLinearCombinedClassifier的实例。运行以下代码时:

estimator.export(
export_path,
signature_fn=tf.contrib.learn.utils.export.classification_signature_fn)

我收到以下警告:

  

警告:tensorflow:调用导出(来自   tensorflow.contrib.learn.python.learn.estimators.estimator)   use_deprecated_input_fn = True已弃用,之后将被删除   2016年9月23日。更新说明:input_fn的签名   出口接受的变化与使用的变化一致   tf.Learn Estimator的火车/评估。 input_fn和input_feature_key   将成为必需的args,而use_deprecated_input_fn将默认为   错误并完全删除。

问题是,我现在可以忽略这个警告吗?另一个问题,我如何为客户编写代码?如何正确准备protobuf?

我看到,对于mnist客户端,他们按如下方式准备protobuf

request = predict_pb2.PredictRequest()
request.model_spec.name = 'mnist'
image, label = test_data_set.next_batch(1)
request.inputs['images'].CopyFrom(
    tf.contrib.util.make_tensor_proto(image[0], shape=[1, image[0].size]))

如何对contrib.learn估算工具中使用的要素列执行相同操作?例如,如果要素列如下?

  country = sparse_column_with_vocabulary_file("country", VOCAB_FILE)
  age = real_valued_column("age")
  click_bucket = bucketized_column(real_valued_column("historical_click_ratio"),
                                   boundaries=[i/10. for i in range(10)])
  country_x_click = crossed_column([country, click_bucket], 10)
  feature_columns = set([age, click_bucket, country_x_click])

...来自https://www.tensorflow.org/versions/r0.11/tutorials/wide/index.html等教程之一的input_fn将提供从客户端发送的实际数据

更新

我运行了导出/客户端组合,但结果看起来不对。部分导出代码如下:

def my_classification_signature_fn(examples, unused_features, predictions):
  """Creates classification signature from given examples and predictions.
  Args:
    examples: `Tensor`.
    unused_features: `dict` of `Tensor`s.
    predictions: `Tensor` or dict of tensors that contains the classes tensor
      as in {'classes': `Tensor`}.
  Returns:
    Tuple of default classification signature and empty named signatures.
  Raises:
    ValueError: If examples is `None`.
  """
  if examples is None:
    raise ValueError('examples cannot be None when using this signature fn.')

  if isinstance(predictions, dict):
    default_signature = exporter.classification_signature(
        examples, classes_tensor=predictions['classes'])
  else:
    print examples
    print predictions
    default_signature = exporter.classification_signature(
        examples, classes_tensor=predictions)
  named_graph_signatures={
        'inputs': exporter.generic_signature({'x_values': examples}),
        'outputs': exporter.generic_signature({'preds': predictions})}    
  return default_signature, named_graph_signatures

def export_input_fn(df,feature_defs,batch_size):
    input_features = input_fn(df, feature_defs)
    input_features["input_feature_dummy"] = tf.constant(np.array([["__SOME_VAL__" for __ in range(len(df.columns))] for _ in range(batch_size) ]),
                                      dtype=tf.string,
                                      shape=[batch_size, len(df.columns)])
    return input_features,None

model.export(
    "export_8",
    input_fn=lambda : export_input_fn(extdf,training_run_data["features"],20),
    input_feature_key="input_feature_dummy",
    signature_fn=my_classification_signature_fn,
    use_deprecated_input_fn=False)

部分客户端代码如下:

pred_df = pd.read_csv("{}/client_test.csv".format(work_dir)).sort_index(axis=1)
host, port = hostport.split(':')
channel = implementations.insecure_channel(host, int(port))
stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)
result_counter = _ResultCounter(num_tests, concurrency)


x_values = numpy.array([["__SOME_VAL__" for __ in range(len(pred_df.columns))] for _ in range(20)])
labels = pred_df["__label"].values
input_values = pred_df.astype('str').values
input_size = min(20, input_values.shape[0])
x_values[0:input_size] = input_values[0:input_size]

labels_test = numpy.zeros(20)
labels_test[0:input_size] = labels[0:input_size]

request = predict_pb2.PredictRequest()
request.model_spec.name = 'capture_process'

request.inputs['x_values'].CopyFrom(
    tf.contrib.util.make_tensor_proto(x_values, shape=x_values.shape))


result_counter.throttle()
result_future = stub.Predict.future(request, 5.0)  # 5 seconds

result_future.add_done_callback(
    _create_rpc_callback(labels_test, result_counter))

正如您所看到的,我正在尝试将从pandas数据帧中提取的所有输入列作为一个称为x_values的张量原型发送。这是正确的还是应该为每个要素列创建输入?

0 个答案:

没有答案