如何在Tensorflow的Java API中使用`saver.save`加载模型

时间:2017-03-15 23:15:21

标签: java tensorflow

我正在尝试使用新的Java API从磁盘读取模型。

The one example使用Tensorflow的Java API演示了如何读取具有图形定义和参数权重的.pb模型文件。

在Python方面,Tensorflow建议使用Saver对象将模型保存到磁盘。它创建一个.meta文件,该文件具有定义并具有权重的.data个文件。在Python中,我使用new_saver=tf.train.import_meta_graph(var_filename) new_saver.restore(sess, model_filename)从磁盘读取模型。

如何在Java API中执行此操作?

2 个答案:

答案 0 :(得分:0)

SavedModelBundle课程可能就是你要找的。特别是,SavedModelBundle.load()将返回Session,您可以使用它来执行保存的模型。

请注意,此功能最近才添加到Java API中,因此它尚未存在于二进制版本中,因此在TensorFlow 1.1发布之前,您必须build the Java API from source

答案 1 :(得分:0)

我正在做类似的事情,使用python接口在hadoop集群上训练模型,并使用模型和学习参数在java中进行预测。

java方面的用法非常简单:

SavedModelBundle load = SavedModelBundle.load(modelDir, "serve");
        float[][] resultArray;
        try (Graph g = load.graph()) {
            try (Session s = load.session();
                 Tensor result = s.runner().feed("data", data).fetch("prediction").run().get(0)) {
                resultArray = result.copyTo(new float[10][1]);
            }
        }
        load.close();
        return resultArray;

要获取Feed的名称和获取的操作,您可以打印签名,并使用输入和输出值名称。

print(prediction_signature)

https://github.com/tensorflow/serving/blob/master/tensorflow_serving/example/mnist_saved_model.py#L119