使用org.tensorflow:tensorflow:1.3.0-rc0。
我根据教程https://tensorflow.github.io/serving/serving_inception:
从检查点生成了初始模型inception_saved_model --checkpoint_dir=/root/xmod/inception-v3
这样就行了,生成了saved_model.pb和带有数据的variables /子目录,并将所有这些内容移动到/ tmp / inception-model目录。 现在我试图通过实质上转换https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java来使用这个模型 我正在加载这样的模型而没有错误:
SavedModelBundle modelBundle = SavedModelBundle.load("/tmp/inception-model", "serve");
现在我正在尝试制定查询(类似于此https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java#L112),但我仍然试图弄清楚如何使用Feed和fetch方法:
private static float[] executeInceptionGraph(SavedModelBundle modelBundle, Tensor image) throws Exception {
Tensor result = modelBundle.session().runner().feed(???).fetch(???).run().get(0);
非常感谢任何帮助如何编写此查询。
答案 0 :(得分:0)
您需要在图表中提供与其节点名称相关联的输入(此处是张量图像),从您发布的链接看,教程似乎使用“图像”(请参阅此处https://github.com/tensorflow/serving/blob/master/tensorflow_serving/example/inception_client.py#L49,代码使用python查询到教程https://tensorflow.github.io/serving/serving_inception)中构建的服务器。 然后你也可以用它的名字来获取你的输出节点,在这里看一下服务器响应的样本https://tensorflow.github.io/serving/serving_inception你可以得到“类”或“分数”,这取决于你想要的那个。
因此,两个命令中的一个应该起作用:
Tensor result = modelBundle.session().runner().feed("images", image).fetch("classes").run().get(0);
OR
Tensor result = modelBundle.session().runner().feed("images", image).fetch("scores").run().get(0);
答案 1 :(得分:0)
我发现它仅适用于冷冻模型。 fetch方法的参数是freeze_graph参数output_node_names中使用的参数。见https://github.com/tensorflow/models/blob/master/slim/export_inference_graph.py#L32