我已经训练了一个神经网络,它将4个浮点值作为输入,并为四个类标签返回一个热编码输出。
例如,{2,12,30,4} - > {0,0,1,0}
生成训练模型并将其保存在.pb文件中。然后将模型导入我的Android应用程序的资产文件夹:
inferenceInterface = new TensorFlowInferenceInterface(getAssets(), "tensorflow_lite_xor_nn.pb");
我有以下功能:
private float[] predict(float[] input){
float output[] = new float[4];
inferenceInterface.feed("dense_1_input", input, 4, input.length);
inferenceInterface.run(new String[]{"dense_2/Sigmoid"});
inferenceInterface.fetch("dense_2/Sigmoid", output);
return output;
}
但是我收到了这个错误:
java.lang.IllegalArgumentException:包含4个元素的缓冲区与具有形状的Tensor不兼容[4,4]