使用tensorflow“保存模型”api对java和python中加载的模型进行错误的预测

时间:2018-05-17 02:39:47

标签: tensorflow tensorflow-serving

我正在尝试在Java中加载一个用python训练的模型,并使用保存的模型api(from tensorflow.python.saved_model)进行保存。

我可以在单独的Python脚本和Java中加载它,但Java版本中的预测是错误的。

我用一个简单的模型编写了一个快速示例项目,演示了“bug”(我希望我的误解)。

Python:OrTraining.py

使用Saved Model Api训练后保存模型。

builders = saved_model_builder.SavedModelBuilder(export_path)
builders.add_meta_graph_and_variables(sess, ["or"], signature_def_map={
    "predict": tf.saved_model.signature_def_utils.predict_signature_def(
        inputs= {"images": x_placeholder},
        outputs= {"scores": hypothesis_function})
    })
builders.save()

https://github.com/JsFlo/DebuggingSavedModelJava/blob/master/OrTraining.py

Python:OrLoadSavedModel.py

使用Saved Model Api。在单独的脚本中加载模型。

with tf.Session(graph=tf.Graph()) as sess:
tf.saved_model.loader.load(sess, ["or"], "orTrainingModels")
graph = tf.get_default_graph()
print(graph.get_operations())
x_placeholder = graph.get_tensor_by_name("or_inputs:0")
hypothesis_function = graph.get_tensor_by_name("hypothesis_output:0")
# sess.run("init")
print(sess.run(hypothesis_function, feed_dict={x_placeholder: np.array([
    np.array([1, 0]),
    np.array([0, 1]),
    np.array([0, 0]),
    np.array([1, 1]),
])}))

https://github.com/JsFlo/DebuggingSavedModelJava/blob/master/OrLoadSavedModel.py

Java:OrLoadSavedModel.java

加载

 SavedModelBundle savedModelBundle = SavedModelBundle.load("./orTrainingModels", "or");
 Session session = savedModelBundle.session();

运行

Tensor result = session.runner()
            .feed("or_inputs", tensorInput)
            .fetch("hypothesis_output")
            .run().get(0);

https://github.com/JsFlo/DebuggingSavedModelJava/blob/master/src/main/java/OrLoadSavedModel.java

java版本和python版本加载并运行图形没有问题,但java版本不会输出正确的预测。

起初我认为这是因为权重/偏差没有被加载但是我能够“运行”java版本中的权重/偏见操作,并且看到它具有我在训练后的python脚本。

检查java中的权重(https://github.com/JsFlo/DebuggingSavedModelJava

Tensor result = session.runner()
            .fetch("da_weights")
            .run().get(0);

1 个答案:

答案 0 :(得分:1)

这对于我提供数据的方式来说是一个问题.Tensorflow不想创建Boxed Types的张量(整数与整数/浮点数与浮点数)并且有检查看看你是否试图传入盒装类型,但似乎检查并不全面。

*从https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java *

进行测试
@Test
public void testCreateFromArrayOfBoxed() {
    Integer[] vector = new Integer[] {1, 2, 3, 4};
    try (Tensor<Integer> t = Tensor.create(vector, Integer.class)) {
        fail("Tensor.create() should fail because it was given an array of boxed values"); 
    } catch (IllegalArgumentException e) {
     // The expected exception
   }
}

以下是我的问题的一个例子:

    Float[] input = new Float[]{0f, 1f};
    Tensor tensorOutput = Tensor.create(input);
    float[] floatOutput= new float[2];
    tensorOutput.copyTo(floatOutput);
    println(Arrays.toString(floatOutput)); // -7.377E30, -7.377E30


    float[] input = new float[]{0f, 1f};
    Tensor tensorOutput = Tensor.create(input);

    float[] floatOutput= new float[2];
    tensorOutput.copyTo(floatOutput);
    println(Arrays.toString(floatOutput)); // 0, 1