使用Java tensorflow v.1.2.0中的Python tensorflow v.0.9.0加载预训练模型

时间:2017-07-04 18:49:20

标签: java python tensorflow

当Java和Python tensorflow版本都是1.2.0时,似乎我们可以使用SavedModelBundle(Java)和Saved Model API(Python)来保存Python tensorflow中训练有素的模型并在Java tensorflow中加载模型(不是与Maven)。

但是,当Python的版本低于1.0时,我无法找到在Java中正确加载模型的方法。

我训练了一个模型并将其保存为Python tensorflow(0.9.0)中的.pb,.sd和.txt文件,并按照tensorflow网站中的example指令加载模型。但是,我收到以下错误:

Exception in thread "main" java.lang.IllegalStateException: Attempting 
to use uninitialized value policy/mean_network/hidden_1/b
            [[Node: _retval_policy/mean_network/hidden_1/b_0_0 = 
_Retval[T=DT_FLOAT, index=0, 
_device="/job:localhost/replica:0/task:0/cpu:0"]
(policy/mean_network/hidden_1/b)]]
            at org.tensorflow.Session.run(Native Method)
            at org.tensorflow.Session.access$100(Session.java:48)
            at org.tensorflow.Session$Runner.runHelper(Session.java:285)
            at org.tensorflow.Session$Runner.run(Session.java:235)
            at Carpole.executeGraph(Carpole.java:42)
            at Carpole.main(Carpole.java:30)

有没有人知道如何在不使用Saved Model API的情况下在最新版本中正确加载Java中的预训练模型(因为我再也找不到以前的版本API)?

提前致谢!

这是我保存的Python代码:

with tf.Session() as sess:
    self.saver = tf.train.Saver(tf.all_variables())
    sess.run(tf.initialize_all_variables())
    …..
    saver_def = self.saver.as_saver_def()
    print(saver_def.filename_tensor_name)
    print(saver_def.restore_op_name)

    self.saver.save(sess, 'trained_model'+str(itr)+'.sd')
    tf.train.write_graph(sess.graph_def, '.', 'trained_model'+str(itr)+'.pb', as_text=False)
    tf.train.write_graph(sess.graph_def, '.', 'trained_model'+str(itr)+'.txt', as_text=True)

这是我的Java代码

public static void main(String[] args) throws Exception {
    String dataDirPath = args[0];
    byte[] graphDef = readAllBytesOrExit(Paths.get(dataDirPath, "trained_model10.pb"));
    List<String> labels = readAllLinesOrExit(Paths.get(dataDirPath, "trained_model10.txt"));
    float[] vector = new float[4];
    vector[0] = (float) -0.09341373;
    vector[1] = (float) -0.07540844;
    vector[2] = (float)  0.00930138;
    vector[3] = (float) -0.14317159;
    Tensor input = Tensor.create(vector);

    float[] labelProbabilities = executeGraph(graphDef, input);
    int bestLabelIdx = maxIndex(labelProbabilities);
    System.out.println(String.format("BEST MATCH: %s (%.2f%% likely)",labels.get(bestLabelIdx), labelProbabilities[bestLabelIdx] * 100f));
}

private static float[] executeGraph(byte[] graphDef, Tensor input_tensor) {
    try (Graph g = new Graph()) {
        g.importGraphDef(graphDef);
        System.out.println(g);
        try (Session s = new Session(g); Tensor result = s.runner().feed("policy/mean_network/input/input",input_tensor).fetch("policy/mean_network/hidden_1/b").run().get(0)) {
        final long[] rshape = result.shape();
        }
        int nlabels = (int) rshape[1];
        return result.copyTo(new float[1][nlabels])[0];
    }
}

1 个答案:

答案 0 :(得分:0)

一般情况下,1.0之前的TensorFlow版本不能保证使用TensorFlow版本&gt; = 1.0(根据TensorFlow Version Semantics并使用语义版本控制)。

也就是说,查看您提供的代码片段,您在Java中加载计算图,但是没有加载已保存的变量,因此您会收到一个异常,抱怨变量尚未初始化。

将图形和已保存的变量封装到单个包中是SavedModel格式的用途。但是,如果您不能使用它,并且如果您只需要在Java中推断图形,那么您可能需要考虑&#34;冻结&#34;图表,然后在Java中加载。冻结的图形将包含单个文件中的所有变量值。

您可以尝试使用freeze_graph库(来自0.9版本分支)来保存这样的图表。

使用TensorFlow的1.0之前版本可能是一个挑战。如果可能,我强烈建议您将模型移至TensorFlow版本&gt; = 1.0,然后可以使用API​​稳定性保证。

希望有所帮助。