我正在尝试使用新的Java API从磁盘读取模型。
The one example使用Tensorflow的Java API演示了如何读取具有图形定义和参数权重的.pb
模型文件。
在Python方面,Tensorflow建议使用Saver
对象将模型保存到磁盘。它创建一个.meta
文件,该文件具有定义并具有权重的.data
个文件。在Python中,我使用new_saver=tf.train.import_meta_graph(var_filename)
new_saver.restore(sess, model_filename)
从磁盘读取模型。
如何在Java API中执行此操作?
答案 0 :(得分:0)
SavedModelBundle
课程可能就是你要找的。特别是,SavedModelBundle.load()
将返回Session
,您可以使用它来执行保存的模型。
请注意,此功能最近才添加到Java API中,因此它尚未存在于二进制版本中,因此在TensorFlow 1.1发布之前,您必须build the Java API from source。
答案 1 :(得分:0)
我正在做类似的事情,使用python接口在hadoop集群上训练模型,并使用模型和学习参数在java中进行预测。
保存模型
你必须使用SavedModelBuilder。在这里你找到指导: https://tensorflow.github.io/serving/serving_basic.html 并且您可以使用他们的mnist示例构造函数签名 https://github.com/tensorflow/serving/blob/master/tensorflow_serving/example/mnist_saved_model.py
java方面的用法非常简单:
SavedModelBundle load = SavedModelBundle.load(modelDir, "serve");
float[][] resultArray;
try (Graph g = load.graph()) {
try (Session s = load.session();
Tensor result = s.runner().feed("data", data).fetch("prediction").run().get(0)) {
resultArray = result.copyTo(new float[10][1]);
}
}
load.close();
return resultArray;
要获取Feed的名称和获取的操作,您可以打印签名,并使用输入和输出值名称。
print(prediction_signature)