In Java

Date: 2019-02-01 07:13:31

Tags: java python tensorflow

I am trying to load a TensorFlow model in Java.

tf.saved_model.simple_save(
                        sess,
                        "/tmp/model/"+timestamp,
                        inputs={"input_x" : cnn.input_x},
                        outputs={"input_y" : cnn.input_y})

This is how I save the TensorFlow model in Python.

import java.io.IOException;
import java.nio.IntBuffer;
import java.util.Random;

import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;

public class Use_model {
    public static void main(String[] args) throws IOException {
        // Print the TensorFlow Java version (1.12.0 in the run below)
        System.out.println(TensorFlow.version());

        // Dummy input: one row of 56 random int32 values, matching input_x's (-1, 56) shape
        Random r = new Random();
        long[] shape = new long[] {1, 56};
        IntBuffer buf = IntBuffer.allocate(1 * 56);
        for (int i = 0; i < 56; i++) {
            buf.put(r.nextInt());
        }
        buf.flip();

        // Load the SavedModel exported by simple_save (tag "serve")
        try (SavedModelBundle b = SavedModelBundle.load("/tmp/model/1549001254", "serve")) {
            Session sess = b.session();

            // Feed input_x, fetch input_y, and copy the [1, 2] float output
            try (Tensor<?> x = Tensor.create(shape, buf)) {
                float[] result = sess.runner()
                        .feed("input_x", x)
                        .fetch("input_y")
                        .run()
                        .get(0)
                        .copyTo(new float[1][2])[0];

                // Print the first output value
                System.out.println(result[0]);
            }
        }
    }
}

This is how I load it in Java.
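The loop above fills input_x with r.nextInt(), which yields arbitrary values across the whole int range, including negatives. That has no bearing on the exception below, but if input_x holds vocabulary indices for an embedding lookup (an assumption based on the int32 (-1, 56) signature; the Python model code is not shown), real inference would need valid ids. A minimal stdlib-only sketch of building such a row; the class name InputRow, the padding id 0 and the sample ids are all hypothetical.

```java
import java.nio.IntBuffer;
import java.util.Arrays;

public class InputRow {
    // Sequence length taken from the (-1, 56) input_x signature
    public static final int SEQ_LEN = 56;

    /**
     * Copies token ids into a buffer of exactly SEQ_LEN ints: longer inputs are
     * truncated, shorter ones are right-padded with padId (padding id is an
     * assumption). The buffer is flipped so it is ready to be read.
     */
    public static IntBuffer toRow(int[] tokenIds, int padId) {
        IntBuffer buf = IntBuffer.allocate(SEQ_LEN);
        for (int i = 0; i < SEQ_LEN; i++) {
            buf.put(i < tokenIds.length ? tokenIds[i] : padId);
        }
        buf.flip();
        return buf;
    }

    public static void main(String[] args) {
        // Hypothetical token ids; real ids would come from the training vocabulary
        IntBuffer buf = toRow(new int[] {12, 7, 431}, 0);
        int[] row = new int[buf.remaining()];
        buf.get(row);
        // prints "56 [12, 7, 431, 0, 0]"
        System.out.println(row.length + " " + Arrays.toString(Arrays.copyOf(row, 5)));
    }
}
```

The returned buffer has the same position/limit state as the question's buf after flip(), so it could be passed to Tensor.create(shape, buf) in place of the random data.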

The given SavedModel SignatureDef contains the following input(s):
  inputs['input_x'] tensor_info:
      dtype: DT_INT32
      shape: (-1, 56)
      name: input_x:0
The given SavedModel SignatureDef contains the following output(s):
  outputs['input_y'] tensor_info:
      dtype: DT_FLOAT
      shape: (-1, 2)
      name: input_y:0
Method name is: tensorflow/serving/predict

The inputs and outputs appear to be saved correctly.
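The signature listing pairs each key (inputs['input_x']) with a graph tensor name (name: input_x:0). In TF Java 1.x, Session.Runner.feed and fetch take graph operation names (optionally in "op:index" form); they do not look up signature keys. In this model the keys and the names happen to coincide. A small sketch, assuming the listing text has exactly the layout shown above, that extracts key-to-name pairs so the names can be checked rather than guessed; SignatureNames and parse are hypothetical names.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class SignatureNames {
    // Matches lines like "inputs['input_x'] tensor_info:" (layout assumed from the listing above)
    private static final Pattern KEY =
            Pattern.compile("(inputs|outputs)\\['([^']+)'\\] tensor_info:");
    // Matches lines like "name: input_x:0"
    private static final Pattern NAME = Pattern.compile("name:\\s*(\\S+)");

    /** Maps "inputs/<key>" or "outputs/<key>" to the graph tensor name listed under it. */
    public static Map<String, String> parse(String listing) {
        Map<String, String> result = new LinkedHashMap<>();
        String pendingKey = null;
        for (String raw : listing.split("\\R")) {
            String line = raw.trim();
            Matcher k = KEY.matcher(line);
            if (k.matches()) {
                pendingKey = k.group(1) + "/" + k.group(2);
                continue;
            }
            Matcher n = NAME.matcher(line);
            if (pendingKey != null && n.matches()) {
                result.put(pendingKey, n.group(1));
                pendingKey = null;
            }
        }
        return result;
    }

    public static void main(String[] args) {
        String listing = String.join("\n",
                "The given SavedModel SignatureDef contains the following input(s):",
                "  inputs['input_x'] tensor_info:",
                "      dtype: DT_INT32",
                "      shape: (-1, 56)",
                "      name: input_x:0",
                "The given SavedModel SignatureDef contains the following output(s):",
                "  outputs['input_y'] tensor_info:",
                "      dtype: DT_FLOAT",
                "      shape: (-1, 2)",
                "      name: input_y:0",
                "Method name is: tensorflow/serving/predict");
        // prints {inputs/input_x=input_x:0, outputs/input_y=input_y:0}
        System.out.println(parse(listing));
    }
}
```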

1.12.0
2019-02-01 15:58:59.065677: I tensorflow/cc/saved_model/reader.cc:31] Reading SavedModel from: /tmp/model/1549001254
2019-02-01 15:58:59.072601: I tensorflow/cc/saved_model/reader.cc:54] Reading meta graph with tags { serve }
2019-02-01 15:58:59.085912: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2
2019-02-01 15:58:59.132271: I tensorflow/cc/saved_model/loader.cc:162] Restoring SavedModel bundle.
2019-02-01 15:58:59.199331: I tensorflow/cc/saved_model/loader.cc:138] Running MainOp with key legacy_init_op on SavedModel bundle.
2019-02-01 15:58:59.199435: I tensorflow/cc/saved_model/loader.cc:259] SavedModel load for tags { serve }; Status: success. Took 133774 microseconds.
Exception in thread "main" java.lang.IllegalArgumentException: You must feed a value for placeholder tensor 'input_y' with dtype float and shape [?,2]
     [[{{node input_y}} = Placeholder[_output_shapes=[[?,2]], dtype=DT_FLOAT, shape=[?,2], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
    at org.tensorflow.Session.run(Native Method)
    at org.tensorflow.Session.access$100(Session.java:48)
    at org.tensorflow.Session$Runner.runHelper(Session.java:314)
    at org.tensorflow.Session$Runner.run(Session.java:264)
    at Use_model.main(Use_model.java:38)

But it fails to load the model; the error message is shown above.

I don't know where the problem is or how to fix it.

1 Answer:

Answer 0 (score: 0)

There is some confusion about input_y in your code. The exception says:

You must feed a value for placeholder tensor 'input_y' with dtype float and shape [?,2]

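This means that input_y is defined as a placeholder in your Python code. My guess is that it is the placeholder holding the labels for the input_x items. input_y should then be used in the loss function to compare the last layer of your CNN (let's call it cnn.output) with the actual labels (cnn.input_y), for example: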

loss = tf.square(cnn.input_y - cnn.output)

Your Python code should then save cnn.output in the outputs dictionary instead of cnn.input_y:

tf.saved_model.simple_save(
                    sess,
                    "/tmp/model/"+timestamp,
                    inputs={"input_x" : cnn.input_x},
                    outputs={"output" : cnn.output})

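Then, in your Java code, fetch "output". Note that Session.Runner.fetch() takes the name of an operation in the graph, not a SignatureDef key, so this works only if the tensor behind cnn.output is actually named "output" (for example, created with name="output" or wrapped with tf.identity(..., name="output")). Otherwise, fetch the name shown on the name: line of the signature listing: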

float[] result = sess.runner()
        .feed("input_x", x)
        .fetch("output")
        .run()
        .get(0)
        .copyTo(new float[1][2])[0];
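result then holds the two values of one row of the (-1, 2) output, and the question's code prints only result[0]. If those values are per-class scores of a two-class classifier (an assumption; the model code is not shown), the predicted class is the index of the larger one. A minimal sketch; Prediction, argmax and the sample scores are hypothetical.

```java
public class Prediction {
    /** Returns the index of the largest score; ties go to the lower index. */
    public static int argmax(float[] scores) {
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Hypothetical values standing in for one row of the (-1, 2) output
        float[] result = {0.2f, 1.3f};
        // prints "predicted class: 1"
        System.out.println("predicted class: " + argmax(result));
    }
}
```

Taking the argmax gives the same class whether the two values are raw logits or softmax probabilities, since softmax preserves ordering.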