从Java中的tflite模型推断

时间:2019-05-09 14:22:24

标签: java tensorflow

我已经导出了一个tflite模型,并在this链接上使用了Python代码,因此能够从该模型进行推断。但是,现在我正在尝试使用Java在Android应用中进行推理。我一直在关注官方文档here,但无法正常使用。有人可以指导我如何完成它吗?我所需要的只是以下步骤。

  1. 读取tflite模型。
  2. 生成形状为[1,200,3]的虚拟记录
  3. tflite模型中获得推断并进行打印。

我一直在阅读tflite演示,但仍然无法解决。要加载模型,我使用

Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model)

从官方文档中获取以下错误:

error: no suitable constructor found for Interpreter(String)
constructor Interpreter.Interpreter(File) is not applicable
  (argument mismatch; String cannot be converted to File)

我无法解决。我该怎么做这个简单的任务?

2 个答案:

答案 0 :(得分:1)

您可以将TFLite模型粘贴到应用程序的资产文件夹中。然后,使用此代码加载其MappedByteBuffer

private MappedByteBuffer loadModelFile() throws IOException {
        String MODEL_ASSETS_PATH = "recog_model.tflite";
        AssetFileDescriptor assetFileDescriptor = context.getAssets().openFd(MODEL_ASSETS_PATH) ;
        FileInputStream fileInputStream = new FileInputStream( assetFileDescriptor.getFileDescriptor() ) ;
        FileChannel fileChannel = fileInputStream.getChannel() ;
        long startoffset = assetFileDescriptor.getStartOffset() ;
        long declaredLength = assetFileDescriptor.getDeclaredLength() ;
        return fileChannel.map( FileChannel.MapMode.READ_ONLY , startoffset , declaredLength ) ;
}

然后在构造函数中调用它。

Interpreter interpreter = new Interpreter( loadModelFile() )

答案 1 :(得分:0)

我找到了解决方案。问题是,new Interpreter(file_of_a_tensorflowlite_model)没有使用字符串文件名作为输入。您必须将其设置为MappedByteBuffer并将其传递给Interpreter

new Interpreter(my_byte_buffer_method(abc.tflite))

之后可以正常工作。如果其他人遇到相同的问题,则只需发布。