Inferencing with Tensorflow Serving using Java

Time: 2019-04-16 23:28:58

Tags: tensorflow tensorflow-serving grpc-java

We are transitioning existing Java production code to use Tensorflow Serving (TFS) for inference. We have already retrained our models and saved them in the new SavedModel format (no more frozen graphs!).
From the documentation I have read, TFS does not directly support Java. However, it does provide a gRPC interface, and gRPC does provide a Java interface.

My question: what are the steps involved in bringing up a Java application to use TFS?

[Edit: moved steps to a solution]

1 answer:

Answer 0: (score: 1)

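Since documentation and examples are still limited, it took four days to piece this together.
I'm sure there are better ways to do this, but this is what I have found so far: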

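  • I cloned the tensorflow/tensorflow, tensorflow/serving, and google/protobuf repos from GitHub.
  • I compiled the following protobuf files using the protoc compiler with the grpc-java protobuf plugin. I dislike how many scattered .proto files have to be compiled, but I wanted the minimal set, and the various directories contain many unneeded .proto files that would otherwise be pulled in. This is the minimal set I needed to compile our Java application: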
    • serving_repo/tensorflow_serving/apis/*.proto
    • serving_repo/tensorflow_serving/config/model_server_config.proto
    • serving_repo/tensorflow_serving/core/logging.proto
    • serving_repo/tensorflow_serving/core/logging_config.proto
    • serving_repo/tensorflow_serving/util/status.proto
    • serving_repo/tensorflow_serving/sources/storage_path/file_system_storage_path_source.proto
    • serving_repo/tensorflow_serving/config/log_collector_config.proto
    • tensorflow_repo/tensorflow/core/framework/tensor.proto
    • tensorflow_repo/tensorflow/core/framework/tensor_shape.proto
    • tensorflow_repo/tensorflow/core/framework/types.proto
    • tensorflow_repo/tensorflow/core/framework/resource_handle.proto
    • tensorflow_repo/tensorflow/core/example/example.proto
    • tensorflow_repo/tensorflow/core/protobuf/tensorflow_server.proto
    • tensorflow_repo/tensorflow/core/example/feature.proto
    • tensorflow_repo/tensorflow/core/protobuf/named_tensor.proto
    • tensorflow_repo/tensorflow/core/protobuf/config.proto
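  • Note that protoc will still compile even if the grpc-java plugin is not present, but most of the key entry points will be mysteriously missing. If PredictionServiceGrpc.java is missing, the grpc-java plugin was not executed.
  • Command-line example (line breaks inserted for readability):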
$ ./protoc -I=/Users/foobar/protobuf_repo/src \
   -I=/Users/foobar/tensorflow_repo \
   -I=/Users/foobar/tfserving_repo \
   --plugin=protoc-gen-grpc-java=/Users/foobar/protoc-gen-grpc-java-1.20.0-osx-x86_64.exe \
   --java_out=src \
   --grpc-java_out=src \
   /Users/foobar/tfserving_repo/tensorflow_serving/apis/*.proto
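Because a missing plugin does not necessarily produce an error, it can help to check the output directory after running protoc. The following is a minimal sketch and not part of the original answer: the class name GeneratedStubCheck and the method containsFile are hypothetical, and it assumes the generated sources land somewhere under the directory passed to --java_out and --grpc-java_out (src in the command above). It uses only the Java standard library.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.stream.Stream;

/** Hypothetical helper: checks whether the grpc-java plugin produced its stub file. */
final class GeneratedStubCheck {

    /** Returns true if a regular file with the given name exists anywhere under outDir. */
    static boolean containsFile(Path outDir, String fileName) {
        if (!Files.isDirectory(outDir)) {
            return false;
        }
        // Walk the tree recursively; the stream must be closed, hence try-with-resources.
        try (Stream<Path> paths = Files.walk(outDir)) {
            return paths.anyMatch(p ->
                    Files.isRegularFile(p) && p.getFileName().toString().equals(fileName));
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        // Assumption: "src" is the output directory, as in the protoc command above.
        Path outDir = Paths.get(args.length > 0 ? args[0] : "src");
        if (containsFile(outDir, "PredictionServiceGrpc.java")) {
            System.out.println("gRPC stubs found under " + outDir);
        } else {
            System.out.println("PredictionServiceGrpc.java not found under " + outDir
                    + "; the grpc-java plugin probably did not run");
        }
    }
}
```

Running this once after code generation, for example as a build step, would surface the missing-plugin case before the Java application fails to compile.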
  • Following the gRPC documentation, I created a Channel and a stub:
ManagedChannel mChannel;
PredictionServiceGrpc.PredictionServiceBlockingStub mBlockingstub;
mChannel = ManagedChannelBuilder.forAddress(host,port).usePlaintext().build();
mBlockingstub = PredictionServiceGrpc.newBlockingStub(mChannel);
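  • I followed several documents to put together the next steps:
    • The gRPC documents discuss stubs (blocking and async)
    • This article outlines the process, but uses Python
    • This example code was critical for examples of the newBuilder syntax.
  • The Maven imports are: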
    • io.grpc:grpc-all
    • org.tensorflow:libtensorflow
    • org.tensorflow:proto
    • com.google.protobuf:protobuf-java
  • Here is the sample code:
// Generate features TensorProto
TensorProto.Builder featuresTensorBuilder = TensorProto.newBuilder();

TensorShapeProto.Dim featuresDim1  = TensorShapeProto.Dim.newBuilder().setSize(1).build();
TensorShapeProto     featuresShape = TensorShapeProto.newBuilder().addDim(featuresDim1).build();
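// setDtype needs a concrete enum value; DT_FLOAT is an assumption, use the type your model's input expects
featuresTensorBuilder.setDtype(org.tensorflow.framework.DataType.DT_FLOAT).setTensorShape(featuresShape);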
TensorProto featuresTensorProto = featuresTensorBuilder.build();


// Now prepare for the inference request over gRPC to the TF Serving server
com.google.protobuf.Int64Value version = com.google.protobuf.Int64Value.newBuilder().setValue(mGraphVersion).build();

Model.ModelSpec.Builder model = Model.ModelSpec
                                     .newBuilder()
                                     .setName(mGraphName)
                                     .setVersion(version);  // type = Int64Value
Model.ModelSpec     modelSpec = model.build();

Predict.PredictRequest request;
request = Predict.PredictRequest.newBuilder()
                                .setModelSpec(modelSpec)
                                .putInputs("image", featuresTensorProto)
                                .build();

Predict.PredictResponse response;

try {
    response = mBlockingstub.predict(request);
    // Refer to https://github.com/thammegowda/tensorflow-grpc-java/blob/master/src/main/java/edu/usc/irds/tensorflow/grpc/TensorflowObjectRecogniser.java

    // getOutputsOrDefault requires a key and a default value; getOutputsMap() returns the whole map
    java.util.Map<java.lang.String, org.tensorflow.framework.TensorProto> outputs = response.getOutputsMap();
    for (java.util.Map.Entry<java.lang.String, org.tensorflow.framework.TensorProto> entry : outputs.entrySet()) {
        System.out.println("Response with the key: " + entry.getKey() + ", value: " + entry.getValue());
    }
} catch (StatusRuntimeException e) {
    logger.log(Level.WARNING, "RPC failed: {0}", e.getStatus());
    success = false;
}
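
The sample above sets a shape on the features tensor but never adds any values. For real inputs, the values go into the TensorProto as a flat list in row-major order, with a matching shape. The sketch below is not from the original answer and shows only the flattening step, using the standard library; the class TensorFlattening and its method names are hypothetical. Each element of the flattened array would then be passed to featuresTensorBuilder.addFloatVal(...), and each shape entry to TensorShapeProto.Dim.newBuilder().setSize(...).

```java
import java.util.Arrays;

/** Hypothetical helper: prepares a 2-D float array for a flat, row-major tensor field. */
final class TensorFlattening {

    /** Returns {rows, cols} for a rectangular 2-D array; rejects ragged input. */
    static long[] shapeOf(float[][] data) {
        int cols = data.length == 0 ? 0 : data[0].length;
        for (float[] row : data) {
            if (row.length != cols) {
                throw new IllegalArgumentException("ragged array: all rows must have " + cols + " columns");
            }
        }
        return new long[] {data.length, cols};
    }

    /** Flattens a rectangular 2-D array in row-major order (row 0 first, then row 1, ...). */
    static float[] flatten(float[][] data) {
        long[] shape = shapeOf(data);
        float[] flat = new float[(int) (shape[0] * shape[1])];
        int i = 0;
        for (float[] row : data) {
            for (float v : row) {
                flat[i++] = v;
            }
        }
        return flat;
    }

    public static void main(String[] args) {
        float[][] features = {{1f, 2f, 3f}, {4f, 5f, 6f}};
        System.out.println(Arrays.toString(shapeOf(features)));  // [2, 3]
        System.out.println(Arrays.toString(flatten(features)));  // [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
    }
}
```

Keeping the shape and the flat values derived from the same array avoids a mismatch between the declared dimensions and the number of values sent to the server.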