当Java和Python tensorflow版本都是1.2.0时,似乎我们可以使用SavedModelBundle(Java)和Saved Model API(Python)来保存Python tensorflow中训练有素的模型并在Java tensorflow中加载模型(不是与Maven)。
但是,当Python的版本低于1.0时,我无法找到在Java中正确加载模型的方法。
我训练了一个模型并将其保存为Python tensorflow(0.9.0)中的.pb,.sd和.txt文件,并按照tensorflow网站中的example指令加载模型。但是,我收到以下错误:
Exception in thread "main" java.lang.IllegalStateException: Attempting
to use uninitialized value policy/mean_network/hidden_1/b
[[Node: _retval_policy/mean_network/hidden_1/b_0_0 =
_Retval[T=DT_FLOAT, index=0,
_device="/job:localhost/replica:0/task:0/cpu:0"]
(policy/mean_network/hidden_1/b)]]
at org.tensorflow.Session.run(Native Method)
at org.tensorflow.Session.access$100(Session.java:48)
at org.tensorflow.Session$Runner.runHelper(Session.java:285)
at org.tensorflow.Session$Runner.run(Session.java:235)
at Carpole.executeGraph(Carpole.java:42)
at Carpole.main(Carpole.java:30)
有没有人知道如何在不使用Saved Model API的情况下在最新版本中正确加载Java中的预训练模型(因为我再也找不到以前的版本API)?
提前致谢!
这是我保存的Python代码:
with tf.Session() as sess:
self.saver = tf.train.Saver(tf.all_variables())
sess.run(tf.initialize_all_variables())
…..
saver_def = self.saver.as_saver_def()
print(saver_def.filename_tensor_name)
print(saver_def.restore_op_name)
self.saver.save(sess, 'trained_model'+str(itr)+'.sd')
tf.train.write_graph(sess.graph_def, '.', 'trained_model'+str(itr)+'.pb', as_text=False)
tf.train.write_graph(sess.graph_def, '.', 'trained_model'+str(itr)+'.txt', as_text=True)
这是我的Java代码
public static void main(String[] args) throws Exception {
String dataDirPath = args[0];
byte[] graphDef = readAllBytesOrExit(Paths.get(dataDirPath, "trained_model10.pb"));
List<String> labels = readAllLinesOrExit(Paths.get(dataDirPath, "trained_model10.txt"));
float[] vector = new float[4];
vector[0] = (float) -0.09341373;
vector[1] = (float) -0.07540844;
vector[2] = (float) 0.00930138;
vector[3] = (float) -0.14317159;
Tensor input = Tensor.create(vector);
float[] labelProbabilities = executeGraph(graphDef, input);
int bestLabelIdx = maxIndex(labelProbabilities);
System.out.println(String.format("BEST MATCH: %s (%.2f%% likely)",labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f));
}
private static float[] executeGraph(byte[] graphDef, Tensor input_tensor) {
try (Graph g = new Graph()) {
g.importGraphDef(graphDef);
System.out.println(g);
try (Session s = new Session(g); Tensor result = s.runner().feed("policy/mean_network/input/input",input_tensor).fetch("policy/mean_network/hidden_1/b").run().get(0)) {
final long[] rshape = result.shape();
}
int nlabels = (int) rshape[1];
return result.copyTo(new float[1][nlabels])[0];
}
}
答案 0 :(得分:0)
一般情况下,1.0之前的TensorFlow版本不能保证使用TensorFlow版本&gt; = 1.0(根据TensorFlow Version Semantics并使用语义版本控制)。
也就是说,查看您提供的代码片段,您在Java中加载计算图,但是没有加载已保存的变量,因此您会收到一个异常,抱怨变量尚未初始化。
将图形和已保存的变量封装到单个包中是SavedModel格式的用途。但是,如果您不能使用它,并且如果您只需要在Java中推断图形,那么您可能需要考虑&#34;冻结&#34;图表,然后在Java中加载。冻结的图形将包含单个文件中的所有变量值。
您可以尝试使用freeze_graph
库(来自0.9版本分支)来保存这样的图表。
使用TensorFlow的1.0之前版本可能是一个挑战。如果可能,我强烈建议您将模型移至TensorFlow版本&gt; = 1.0,然后可以使用API稳定性保证。
希望有所帮助。