使用deeplearning4j加载keras模型时出错

时间:2019-04-12 21:58:42

标签: android gradle android-gradle build.gradle deeplearning4j

一段时间以来,我一直在努力为Android应用程序使用deeplearning4j加载我的keras神经网络模型。我一直在寻找解决方案(尽可能多),但是每个解决方案都会带来新的错误,而我只是无法使这件事起作用。

无论如何,我已经在Python中使用keras训练了 NON 顺序模型,并像这样保存了它:

model.save('model.h5')

现在,我正在尝试使用Android Studio中的deeplearning4j导入此模型。我尝试了许多可能的变体,但这是我现在的位置:

String modelPath = new ClassPathResource("res/raw/model.h5").getFile().getPath();
ComputationGraph model = KerasModelImport.importKerasModelAndWeights(modelPath)

但这会触发以下错误:

java.lang.NoClassDefFoundError: Failed resolution of: Lorg/bytedeco/javacpp/hdf5;

据我了解,gradle无法解析来自hdf5的依赖项org.bytedeco,这是我同意的,因为我已在gradle构建中排除了hdf5-platform,但是hdf5就我所知(?),Android甚至不应该支持。

我还尝试过包含hdf5-platform并运行相同的代码,但是这样做会引发另一个错误:

java.lang.UnsatisfiedLinkError: Platform "android-arm64" not supported by class org.bytedeco.javacpp.hdf5

我对gradle概念还很陌生,我不深入了解Android,但是问题似乎出在我的gradle依赖项上。关于deeplearning4j的信息也非常有限,我也找不到其他解决方案。

我还将包括this tutorial.

中的gradle依赖项
implementation (group: 'org.deeplearning4j', name: 'deeplearning4j-core', version: '1.0.0-beta3') {
    exclude group: 'org.bytedeco.javacpp-presets', module: 'opencv-platform'
    exclude group: 'org.bytedeco.javacpp-presets', module: 'leptonica-platform'
    exclude group: 'org.bytedeco.javacpp-presets', module: 'hdf5-platform'
    exclude group: 'org.nd4j', module: 'nd4j-base64'
}
implementation group: 'org.nd4j', name: 'nd4j-native', version: '1.0.0-beta3'
implementation group: 'org.nd4j', name: 'nd4j-native', version: '1.0.0-beta3', classifier: "android-arm"
implementation group: 'org.nd4j', name: 'nd4j-native', version: '1.0.0-beta3', classifier: "android-arm64"
implementation group: 'org.nd4j', name: 'nd4j-native', version: '1.0.0-beta3', classifier: "android-x86"
implementation group: 'org.nd4j', name: 'nd4j-native', version: '1.0.0-beta3', classifier: "android-x86_64"
implementation group: 'org.bytedeco.javacpp-presets', name: 'openblas', version: '0.3.3-1.4.3'
implementation group: 'org.bytedeco.javacpp-presets', name: 'openblas', version: '0.3.3-1.4.3', classifier: "android-arm"
implementation group: 'org.bytedeco.javacpp-presets', name: 'openblas', version: '0.3.3-1.4.3', classifier: "android-arm64"
implementation group: 'org.bytedeco.javacpp-presets', name: 'openblas', version: '0.3.3-1.4.3', classifier: "android-x86"
implementation group: 'org.bytedeco.javacpp-presets', name: 'openblas', version: '0.3.3-1.4.3', classifier: "android-x86_64"
implementation group: 'org.bytedeco.javacpp-presets', name: 'opencv', version: '3.4.3-1.4.3'
implementation group: 'org.bytedeco.javacpp-presets', name: 'opencv', version: '3.4.3-1.4.3', classifier: "android-arm"
implementation group: 'org.bytedeco.javacpp-presets', name: 'opencv', version: '3.4.3-1.4.3', classifier: "android-arm64"
implementation group: 'org.bytedeco.javacpp-presets', name: 'opencv', version: '3.4.3-1.4.3', classifier: "android-x86"
implementation group: 'org.bytedeco.javacpp-presets', name: 'opencv', version: '3.4.3-1.4.3', classifier: "android-x86_64"
implementation group: 'org.bytedeco.javacpp-presets', name: 'leptonica', version: '1.76.0-1.4.3'
implementation group: 'org.bytedeco.javacpp-presets', name: 'leptonica', version: '1.76.0-1.4.3', classifier: "android-arm"
implementation group: 'org.bytedeco.javacpp-presets', name: 'leptonica', version: '1.76.0-1.4.3', classifier: "android-arm64"
implementation group: 'org.bytedeco.javacpp-presets', name: 'leptonica', version: '1.76.0-1.4.3', classifier: "android-x86"
implementation group: 'org.bytedeco.javacpp-presets', name: 'leptonica', version: '1.76.0-1.4.3', classifier: "android-x86_64"

(如何更改我的依赖关系,以使此模型能够正常工作?

还是应该以某种方式更改导入模型的方式?

1 个答案:

答案 0 :(得分:1)

deeplearning4j可能不是更好的选择。要在Android甚至iOS上加载TensorFlow Keras模型,可以使用TensorFlow Lite

首先,您需要将Keras(.h5)模型转换为TFLite模型(.tflite)

import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_keras_model_file( 'model.h5' )
tflite_model = converter.convert()
open( 'model.tflite' , 'wb' ).write( tflite_model )

您可以执行以下操作:

  1. 如果您的模型需要托管在Clod源上(可通过您的应用下载),则可以使用Firebase ML Kit。对于自定义的TFLite模型,请阅读here

  2. 您可以将TFLite模型保留在应用程序的资产文件夹中,然后加载其MappedByteBuffer。可以使用Android的TensorFlow Lite依赖项:

    implementation ‘org.tensorflow:tensorflow-lite:1.13.1’
    

您可以参考此codelab和此article

您可以像这样加载MappedByteBuffer:

private MappedByteBuffer loadModelFile(Activity activity) throws IOException {
  AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(getModelPath());
  FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
  FileChannel fileChannel = inputStream.getChannel();
  long startOffset = fileDescriptor.getStartOffset();
  long declaredLength = fileDescriptor.getDeclaredLength();
  return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}