我按照下面提供的说明在Java中加载训练有素的Tensorflow模型:
How to import an saved Tensorflow model train using tf.estimator and predict on input data
我成功保存了我的python模型并将其导入我的java代码中。
以下是我的python代码中保存模型的部分:
# Export trained model
def serving_input_receiver_fn():
inputs = {"my_feature": tf.placeholder(shape=[None, 1], dtype=tf.float32)}
return tf.estimator.export.ServingInputReceiver(inputs, inputs)
export_dir = kmeans.export_savedmodel(
export_dir_base="/tmp/path_to_model_directory",
serving_input_receiver_fn=serving_input_receiver_fn,
assets_extra={"file_name":"/tmp/path_to_file"
})
这是用于读取模型,进行预测和读取结果的java代码:
Tensor x = Tensor.create(
new long[] {1, 1},
FloatBuffer.wrap(
new float[] {
10000.0f
}));
final String xName = "inputs_tensor_info_name";
final String yName = "outputs_tensor_info_name";
List<Tensor<?>> outputTensor = session.runner()
.feed(xName, x)
.fetch(yName)
.run();
for (int i=0; i < outputTensor.size(); i++ ) {
Integer[] copy = new Integer[1];
System.out.println(outputTensor.get(i));
System.out.println(outputTensor.get(i).copyTo(copy));
}
但是,当尝试在数组中加载预测时,会发生以下异常:
Exception in thread "main" java.lang.IllegalArgumentException: cannot create non-scalar Tensors from arrays of boxed values
at org.tensorflow.Tensor.objectCompatWithType(Tensor.java:722)
at org.tensorflow.Tensor.throwExceptionIfTypeIsIncompatible(Tensor.java:742)
at org.tensorflow.Tensor.copyTo(Tensor.java:450)
at com.pivot.maven.quickstart.HelloTF.main(HelloTF.java:105)
如果没有弄明白如何解决这个问题,如果有人指出我正确的方向,我将不胜感激。 谢谢。