tensorflow.keras用python保存模型并用Java加载

时间:2018-12-12 14:37:11

标签: java python tensorflow keras

我有一个经过微调的vgg模型,我使用 tensorflow.keras 功能API创建了模型,并使用 tf.contrib.saved_model.save_keras_model 保存了模型。 因此,模型使用以下结构保存:资产文件夹,其中包含 saved_model.json 文件, saved_model.pb 文件,以及变量文件夹,其中包含 checkpoint variables.data-00000-of-00001 variables.index

我可以轻松地在python中加载模型并使用 tf.contrib.saved_model.load_keras_model(saved_model_path)获得预测,但是我不知道如何在JAVA中加载模型。我在Google上搜索了很多,发现此How to export Keras .h5 to tensorflow .pb?可以导出为pb文件,然后通过此链接Loading in Java进行加载。我无法冻结图,也尝试使用simple_save,但tensorflow.keras不支持simple_save( AttributeError:模块'tensorflow.contrib.saved_model'没有属性'simple_save')。因此有人可以帮助我确定在JAVA中加载我的模型(tensorflow.keras功能性API模型)需要采取哪些步骤。

我的savedmodel.pb文件是否足以在Java端加载?我需要创建输入/输出占位符吗?那怎么出口呢?
感谢您的帮助。

1 个答案:

答案 0 :(得分:1)

如果您有一个以SavedModel格式保存的模型(看起来确实如此,并且tf.contrib.saved_model.save_keras_model之类的东西可以帮助创建),那么在Java中,您可以使用SavedModelBundle.load来加载和提供服务。您无需“冻结”模型。

您可能会发现以下有用:

但是基本思想是您的代码将类似于:

try (SavedModelBundle model = SavedModelBundle.load("<directory>", "serve")) {
  try (Tensor<?> input = makeInputTensor();
       Tensor<?> output = model.session().runner().feed("INPUT_TENSOR", input).fetch("OUTPUT_TENSOR", output).run().get(0)) {
  // Use output
  }
}

其中"INPUT_TENSOR""OUTPUT_TENSOR"是TensorFlow图中输入和输出节点的名称。在安装TensorFlow for Python时安装的saved_model_cli命令行工具可以在模型中显示这些张量的名称。

请注意,使用TensorFlow Java API可能比使用另一个评论者建议的使用TensorFlow Lite更适合服务器/桌面应用程序。这是因为TensorFLow Lite运行时虽然已针对小型设备进行了优化(在内存占用等方面),但仍无法导出所有模型。 TensorFlow Java API使用完全相同的运行时,因此具有与TensorFlow for Python完全相同的功能。

希望有帮助。