使用Java API超出范围0的切片索引0

时间:2020-04-23 14:38:41

标签: java tensorflow

我已经生成了一个SavedModel,可以与以下Python代码一起使用

import base64
import numpy as np
import tensorflow as tf
​
​
fn_load_image = lambda filename: np.array([base64.urlsafe_b64encode(open(filename, "rb").read())])
filename='test.jpg'
with tf.Session() as sess:
    loaded = tf.saved_model.loader.load(sess, ['serve'], 'tools/base64_model/1')
    image = fn_load_image(filename)
    p = sess.run('predictions:0', feed_dict={"input:0": image})
    print(p)

这给了我期望的值。

在同一模型上使用以下Java代码时

    // load the model Bundle
    try (SavedModelBundle b = SavedModelBundle.load("tools/base64_model/1",
            "serve")) {

        // create the session from the Bundle
        Session sess = b.session();

        // base64 representation of JPG
        byte[] content = IOUtils.toByteArray(new FileInputStream(new File((args[0]))));

        String encodedString = Base64.getUrlEncoder().encodeToString(content);

        Tensor t = Tensors.create(encodedString);

        // run the model and get the classification
        final List<Tensor<?>> result = sess.runner().feed("input", 0, t).fetch("predictions", 0).run();

        // print out the result.
        System.out.println(result);
    }

应该等效,即我将图像的base64表示形式发送给模型,但出现异常

线程“ main”中的异常java.lang.IllegalArgumentException:切片 维度0的索引0超出范围。 [[{{node map / strided_slice}}]] 在org.tensorflow.Session.run(本机方法)在 org.tensorflow.Session.access $ 100(Session.java:48)在 org.tensorflow.Session $ Runner.runHelper(Session.java:326)在 org.tensorflow.Session $ Runner.run(Session.java:276)在 com.stolencamerafinder.storm.crawler.bolt.enrichments.HelloTensorFlow.main(HelloTensorFlow.java:35)

张量应该具有不同的内容吗?这是 saved_model_cli 告诉我的我的模型。

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['inputs'] tensor_info:
        dtype: DT_STRING
        shape: (-1)
        name: input:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['outputs'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 2)
        name: predictions:0
  Method name is: tensorflow/serving/predict

1 个答案:

答案 0 :(得分:1)

您的模型期望输入等级1的张量,而输入等级1的张量。

此行产生一个可变长度的标量张量(即DT_STRING)。

Tensor t = Tensors.create(encodedString);

但是,期望的张量为等级1,如您在此处通过形状(-1)所看到的,这意味着它期望的向量是各种元素的数量。

The given SavedModel SignatureDef contains the following input(s):
    inputs['inputs'] tensor_info:
        dtype: DT_STRING
        shape: (-1)
        name: input:0

因此,可能通过传递字符串数组来解决您的问题。仅当您将字符串作为字节数组数组传递时,才可以使用Tensors工厂来实现,如下所示:

// base64 representation of JPG
byte[] content = IOUtils.toByteArray(new FileInputStream(new File((args[0]))));
byte[] encodedBytes = Base64.getUrlEncoder().encode(content);
Tensor t = Tensors.create(new byte[][]{ encodedBytes });
...