无法在Java API中运行Tensorflow预测

时间:2017-06-20 08:44:21

标签: java python tensorflow

我正在尝试对我使用“Fensuning AlexNet with TensorFlow”训练的模型执行预测 https://kratzert.github.io/2017/02/24/finetuning-alexnet-with-tensorflow.html

我在Python中使用tf.saved_model.builder.SavedModelBuilder保存模型,并使用SavedModelBundle.load在Java中加载模型。 代码的主要部分是:

    SavedModelBundle smb = SavedModelBundle.load(path, "serve");
    Session s = smb.session();
    byte[] imageBytes = readAllBytesOrExit(Paths.get(path));
    Tensor image = constructAndExecuteGraphToNormalizeImage(imageBytes);
    Tensor result = s.runner().feed("input_tensor", image).fetch("fc8/fc8").run().get(0);
    final long[] rshape = result.shape();
    if (result.numDimensions() != 2 || rshape[0] != 1) {
        throw new RuntimeException(
                String.format(
                        "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
                        Arrays.toString(rshape)));
    }
    int nlabels = (int) rshape[1];
    float [] a =  result.copyTo(new float[1][nlabels])[0];`

我得到了这个例外:

  

线程“main”中的异常java.lang.IllegalArgumentException:您必须使用dtype float为占位符张量'Placeholder_1'提供值        [[Node:Placeholder_1 = Placeholder_output_shapes = [[]],dtype = DT_FLOAT,shape = [],_ device =“/ job:localhost / replica:0 / task:0 / cpu:0”]]

我看到上面的代码适用于某些人,我无法弄清楚这里缺少什么。 请注意,网络熟悉节点“input_tensor”和“fc8 / fc8”,因为它没有说它不知道它们。

1 个答案:

答案 0 :(得分:1)

从错误消息中,您正在使用的模型似乎需要提供另一个值(图中的节点名称为Placeholder_1且预期类型为浮点标量张量)。

您似乎已经自定义了模型(而不是跟随您逐字链接的文章)。也就是说,该文章显示了需要馈送的多个占位符,一个用于图像,另一个用于控制丢失。在文章中定义为:

keep_prob = tf.placeholder(tf.float32)

此占位符的值需要输入。如果您正在进行推理,那么您希望将keep_prob设置为1.0。类似的东西:

Tensor keep_prob = Tensor.create(1.0f);
Tensor result = s.runner()
  .feed("input_tensor", image)
  .feed("Placeholder_1", keep_prob)
  .fetch("fc8/fc8")
  .run()
  .get(0);

希望有所帮助。