如何从keras导出的Java中加载tensorflow .pb模型

时间:2019-07-02 20:57:27

标签: java tensorflow keras

我试图在Java中加载一个模型,该模型最初保存在Java的keras中,以便可以在Java中运行的现有生产系统中进行推理处理。

我没有看到在Java中轻松加载Keras h5模型的方法,因此我尝试首先使用simple_save将其转换为.pb文件,然后使用simple_save的默认标记加载它。我尝试直接使用Frozen_session例程和tf.train.write_graph保存该图,但是出现了相同的错误。

这是将模型保存到.pb文件的代码

# my model has two input tensors and one output tensor
inputs = {'input_1': model.inputs[0], 'input_2' : model.inputs[1]}
outputs = {'output_1' : model.outputs[0]}

tf.saved_model.simple_save(K.get_session(), 'output_dir', inputs=inputs, outputs=outputs)

这是我的Java代码,用于使用默认的标签save_model加载模型:

SavedModelBundle model = SavedModelBundle.load("output_dir", "serve");

这会导致错误:

线程“主” org.tensorflow.TensorFlowException中的异常: 在提供的导出目录路径:output_dir

上找不到SavedModel .pb或.pbtxt

知道我可能做错了什么吗?我知道simple_save已过时,但我只是想在这一点上使一切正常工作。

2 个答案:

答案 0 :(得分:0)

我查看了加载模型的the native source code,结果发现目录中有一个硬编码的文件名“ saved_model.pb”,或文本版本“ saved_model.pbtxt”(在我查看的文档中未指定)。

答案 1 :(得分:0)

现在,您可以使用Deep Java Library(DJL)在Java中加载Keras模型并运行推理。 DJL在内部使用tensorflow java并提供高级API以使其易于运行推理和训练。检出github存储库:https://github.com/awslabs/djl

有一个博客文章:https://towardsdatascience.com/detecting-pneumonia-from-chest-x-ray-images-e02bcf705dd6

可以找到演示项目:https://github.com/aws-samples/djl-demo/blob/master/pneumonia-detection/README.md