如何在Android中执行张量流模型的推理

时间:2018-10-01 03:57:04

标签: android tensorflow neural-network tensorflow-lite

我尝试使用Tensorflow Lite,但它有很多限制,它没有批处理规范化操作,即使使用简单的操作,它也为使用Keras测试的相同数据提供了非常奇怪的结果。这意味着使用keras一切正常,使用tensorflow lite,结果是完全错误的。所以我需要一些东西才能在Android上执行.pb文件。

1 个答案:

答案 0 :(得分:2)

您可以使用TensorFlowInferenceInterface通过.pb文件进行预测。首先,将.pb文件放在应用程序的资产文件夹中。

  1. 在您的build.gradle(Module:app)文件中,添加以下依赖项, implementation 'org.tensorflow:tensorflow-android:1.11.0'
  2. 然后初始化TensorFlowInferenceInterface,如果您的模型文件名为“ model.pb”,则, TensorFlowInferenceInterface tensorFlowInferenceInterface = new TensorFlowInferenceInterface(context.getAssets() , "file:///android_asset/model.pb") ;
  3. tensorFlowInferenceInterface.feed( INPUT_NAME , inputs , 1, 28, 28);,其中INPUT_NAME是输入层的名称。 1 , 50是输入尺寸。

  4. tensorFlowInferenceInterface.run( new String[]{ OUTPUT_NAME } );,其中OUTPUT_NAME是您的输出图层的名称。

  5. float[] outputs = new float[ nuymber_of_classes ]; tensorFlowInferenceInterface.fetch( OUTPUT_NAME , outputs ) ;

outputs是根据模型预测的浮点值。

这是完整的代码:

TensorFlowInferenceInterface tensorFlowInferenceInterface = new 
TensorFlowInferenceInterface(context.getAssets() , "file:///android_asset/model.pb");
tensorFlowInferenceInterface.feed( INPUT_NAME , inputs , 1, 28, 28);
tensorFlowInferenceInterface.run( new String[]{ OUTPUT_NAME } );
float[] outputs = new float[ nuymber_of_classes ];
tensorFlowInferenceInterface.fetch( OUTPUT_NAME , outputs ) ;