Keras深度学习模型到android

时间:2017-08-25 03:59:44

标签: android tensorflow deep-learning keras

我正在为android开发一个实时对象分类应用程序。首先,我使用“keras”创建了一个深度学习模型,我已经将训练过的模型保存为“model.h5”文件。我想知道如何在android中使用该模型进行图像分类。

3 个答案:

答案 0 :(得分:7)

您无法将Keras直接导出到Android,但您必须保存模型

  • 将Tensflow配置为Keras后端。

  • 使用model.save(filepath)保存模特摔角(你已经这样做了)

然后使用以下解决方案之一加载它:

解决方案1:在Tensflow中导入模型

1-构建Tensorflow模型

  • 根据keras模型use this code构建张量流模型(链接更新)

2-构建Android应用并调用tensflow。从谷歌查看此tutorial和此official demo,了解如何操作。

解决方案2:在java中导入模型
1- deeplearning4j java库允许导入keras模型:tutorial link
2-在Android中使用deeplearning4j:因为你在java世界中很容易。检查this tutorial

答案 1 :(得分:1)

首先,您需要将 Keras 模型导出为 Tensorflow 模型:

def export_model_for_mobile(model_name, input_node_names, output_node_name):
    tf.train.write_graph(K.get_session().graph_def, 'out', \
        model_name + '_graph.pbtxt')

    tf.train.Saver().save(K.get_session(), 'out/' + model_name + '.chkp')

    freeze_graph.freeze_graph('out/' + model_name + '_graph.pbtxt', None, \
        False, 'out/' + model_name + '.chkp', output_node_name, \
        "save/restore_all", "save/Const:0", \
        'out/frozen_' + model_name + '.pb', True, "")

    input_graph_def = tf.GraphDef()
    with tf.gfile.Open('out/frozen_' + model_name + '.pb', "rb") as f:
        input_graph_def.ParseFromString(f.read())

    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
            input_graph_def, input_node_names, [output_node_name],
            tf.float32.as_datatype_enum)

    with tf.gfile.FastGFile('out/tensorflow_lite_' + model_name + '.pb', "wb") as f:
        f.write(output_graph_def.SerializeToString())

您只需要了解图表的input_nodes_namesoutput_node_names即可。这将创建一个包含多个文件的新文件夹。其中,一个以 tensorflow_lite_ 开头。这是您要移至Android设备的文件。

然后在Android上导入Tensorflow库并使用 TensorFlowInferenceInterface 来运行您的模型。

implementation 'org.tensorflow:tensorflow-android:1.5.0'

您可以在Github上查看我的简单XOR示例:

https://github.com/OmarAflak/Keras-Android-XOR

答案 2 :(得分:0)

如果您想优化分类方法,那么我建议您使用armnn android库对模型进行推断。

您必须遵循几个步骤。 1.在ubuntu中安装和设置arm nn库。您可以从以下网址获取帮助

https://github.com/ARM-software/armnn/blob/branches/armnn_19_08/BuildGuideAndroidNDK.md

  1. 只需导入您的模型并进行推断。您可以从网址下方获取帮助

https://developer.arm.com/solutions/machine-learning-on-arm/developer-material/how-to-guides/deploying-a-tensorflow-mnist-model-on-arm-nn/deploying-a-tensorflow-mnist-model-on-arm-nn-single-page

  1. 编译后,您将获得二进制文件,该二进制文件将接受输入并为您提供输出

  2. 您可以在任何Andriod应用程序中运行该二进制文件

这是优化方式。