如何在Java中为TensorFlow创建TensorProto?

时间:2016-09-12 03:36:26

标签: java tensorflow protocol-buffers

现在我们使用张量流/服务进行推理。它公开了gRPC服务,我们可以从proto文件生成Java类。

现在我们可以从https://github.com/tensorflow/serving/blob/master/tensorflow_serving/apis/prediction_service.proto生成PreditionService但是如何从多维数组中构造TensorProto个对象。

我们有Python ndarray和C ++的一些例子。如果有人尝试使用Java,那将会很棒。

在Java中运行TensorFlow有一些工作。这是blog,但我不确定它是否有效或如何在没有依赖的情况下使用它。

1 个答案:

答案 0 :(得分:1)

TensorProto支持张量内容的两种表示形式:

  1. 各种repeated *_val字段(例如TensorProto.float_valTensorProto.int_val),它们以行主要顺序将内容存储为原始元素的线性数组。

  2. TensorProto.tensor_content字段,将内容存储为单字节数组,对应tensorflow::Tensor::AsProtoTensorContent()的结果。 (通常,此表示形式对应于tensorflow::Tensor的内存中表示形式,转换为字节数组,但DT_STRING类型的处理方式不同。)

  3. 使用第一种格式生成TensorProto对象可能更容易,但效率较低。假设您的Java程序中有一个名为float的2-D tensorData数组,您可以使用以下代码作为起点:

    float[][] tensorData = ...;
    TensorProto.Builder builder = TensorProto.newBuilder();
    
    // Set the shape and dtype fields.
    // ...
    
    // Set the float_val field.
    for (int i = 0; i < tensorData.length; ++i) {
        for (int j = 0; j < tensorData[i].length; ++j) {
            builder.addFloatVal(tensorData[i][j]);
        }
    }
    
    TensorProto tensorProto = builder.build();