使用SavedModelBundle在Java中提供初始模型v3

时间:2017-07-25 21:20:27

标签: java tensorflow tensorflow-serving

使用org.tensorflow:tensorflow:1.3.0-rc0。

我根据教程https://tensorflow.github.io/serving/serving_inception

从检查点生成了初始模型
inception_saved_model --checkpoint_dir=/root/xmod/inception-v3

这样就行了,生成了saved_model.pb和带有数据的variables /子目录,并将所有这些内容移动到/ tmp / inception-model目录。 现在我试图通过实质上转换https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java来使用这个模型 我正在加载这样的模型而没有错误:

SavedModelBundle modelBundle = SavedModelBundle.load("/tmp/inception-model", "serve");

现在我正在尝试制定查询(类似于此https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java#L112),但我仍然试图弄清楚如何使用Feed和fetch方法:

    private static float[] executeInceptionGraph(SavedModelBundle modelBundle, Tensor image) throws Exception {
     Tensor result = modelBundle.session().runner().feed(???).fetch(???).run().get(0);

非常感谢任何帮助如何编写此查询。

2 个答案:

答案 0 :(得分:0)

您需要在图表中提供与其节点名称相关联的输入(此处是张量图像),从您发布的链接看,教程似乎使用“图像”(请参阅​​此处https://github.com/tensorflow/serving/blob/master/tensorflow_serving/example/inception_client.py#L49,代码使用python查询到教程https://tensorflow.github.io/serving/serving_inception)中构建的服务器。 然后你也可以用它的名字来获取你的输出节点,在这里看一下服务器响应的样本https://tensorflow.github.io/serving/serving_inception你可以得到“类”或“分数”,这取决于你想要的那个。

因此,两个命令中的一个应该起作用:

Tensor result = modelBundle.session().runner().feed("images", image).fetch("classes").run().get(0);

OR

 Tensor result = modelBundle.session().runner().feed("images", image).fetch("scores").run().get(0);

答案 1 :(得分:0)

我发现它仅适用于冷冻模型。 fetch方法的参数是freeze_graph参数output_node_names中使用的参数。见https://github.com/tensorflow/models/blob/master/slim/export_inference_graph.py#L32