如何从内部存储加载Tensorflow模型?

时间:2019-07-09 20:26:37

标签: java android tensorflow

我想知道是否可以从Android设备的内部存储而不是资产文件夹中存储和读取经过训练的.tflite模型?

下面是用于从Assets文件夹加载模型的原始代码(有效)。

private MappedByteBuffer loadLocalModelFile() throws IOException {
  AssetFileDescriptor fileDescriptor = getAssets().openFd(MODEL_PATH);
  FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
  long startOffset = fileDescriptor.getStartOffset();
  long declaredLength = fileDescriptor.getDeclaredLength();

  FileChannel fileChannel = inputStream.getChannel();
  return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}

是否有一种方法可以从内部存储器加载模型,并且仍然为fileChannel.map(FileChannel.MapMode.READ_ONLY,startOffset,clarifiedLength)获取startOffset和clarifiedLength?如果没有,从内部存储读取原始二进制文件时,是否有一种方法可以计算新模型的startOffset及其声明的长度?

我尝试使用AssetManager中的openNonAssetFd()函数来获取位于内部存储中的文件的AssetFileDescriptor。

private MappedByteBuffer loadOnlineModelFile() throws IOException {
    FileInputStream inputStream = openFileInput(MODEL);

    AssetManager manager = getAssets();
    AssetFileDescriptor afd = manager.openNonAssetFd(getFilesDir() + "/graph.lite");

    long startOffset = afd.getStartOffset();
    long declaredLength = afd.getDeclaredLength();

    FileChannel fileChannel = inputStream.getChannel();
    return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
  }

但是,这将导致“ java.lang.IllegalArgumentException:模型ByteBuffer应该是模型文件的MappedByteBuffer,或者是使用包含模型内容字节的ByteOrder.nativeOrder()的直接ByteBuffer”和“java.io。 FileNotFoundException”。

2 个答案:

答案 0 :(得分:0)

好吧,我一直在到处搜索,终于找到答案了。很简单。
由于某种原因,我认为AssetFileDescriptor的{​​{1}}与实际的getStartOffset相关,但与实际无关。我认为tflite model在应用程序资产中提供了文件的getStartOffset点。对于starttflite model应该是startOffset,因为这是文件开始的地方,因为它只是一个文件。 因此,代码应为

0

答案 1 :(得分:0)

您可以直接从内部存储访问文件。 这是一个演示代码,用于从位于内部存储中的示例文件夹中读取名为 model.tflite 的 tflite 模型。

 @NonNull
  public MappedByteBuffer loadMappedFile(@NonNull Context context, @NonNull String filePath) throws IOException {
    SupportPreconditions.checkNotNull(context, "Context should not be null.");
    SupportPreconditions.checkNotNull(filePath, "File path cannot be null.");
    File file = new File(Environment.getExternalStorageDirectory() + "/sample/" + filePath);

    MappedByteBuffer var9;
    try {
      FileInputStream inputStream = new FileInputStream(file);
      try {
        FileChannel fileChannel = inputStream.getChannel();
        var9 = fileChannel.map(FileChannel.MapMode.READ_ONLY, 0, file.length());
      } catch (Throwable var12) {
        try {
          inputStream.close();
        } catch (Throwable var11) {
          var12.addSuppressed(var11);
        }
        throw var12;
      }

      inputStream.close();
    } catch (Throwable var13) {
      throw var13;
    }

    return var9;
  }

文件路径将是模型的名称。这是model.tflite。 我们可以这样调用方法,

loadMappedFile(Classifier.this, "model.tflite");