TensorFlowInferenceInterface具有onehot编码输出

时间:2018-05-04 08:48:07

标签: android neural-network tensorflow-lite

我已经训练了一个神经网络,它将4个浮点值作为输入,并为四个类标签返回一个热编码输出。

例如,{2,12,30,4} - > {0,0,1,0}

生成训练模型并将其保存在.pb文件中。然后将模型导入我的Android应用程序的资产文件夹:

inferenceInterface = new TensorFlowInferenceInterface(getAssets(), "tensorflow_lite_xor_nn.pb");

我有以下功能:

private float[] predict(float[] input){
    float output[] = new float[4];

    inferenceInterface.feed("dense_1_input", input, 4, input.length);
    inferenceInterface.run(new String[]{"dense_2/Sigmoid"});
    inferenceInterface.fetch("dense_2/Sigmoid", output);

    return output;
}

但是我收到了这个错误:

  

java.lang.IllegalArgumentException:包含4个元素的缓冲区与具有形状的Tensor不兼容[4,4]

0 个答案:

没有答案