建立tflite模型进行多类别分类

时间:2019-07-31 02:28:24

标签: android tensorflow-lite

我已经阅读了多个代码实验室,其中Google对属于一个类别的图像进行了分类。如果我需要使用2个或更多类,该怎么办。例如,如果我要分类图像是否包含水果或蔬菜,然后分类是水果或蔬菜的类型。

1 个答案:

答案 0 :(得分:1)

您可以使用TensorFlow(特别是使用Keras)轻松训练卷积神经网络(CNN)。互联网上有很多例子。参见herehere

接下来,我们使用.h5将Keras保存的模型(.tflite文件)转换为tf.lite.TFLiteConverter文件,

import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_keras_model_file("keras_model.h5")
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)

请参见here

现在,在Android中,我们拍摄一张Bitmap图像并将其转换为float[][][][]

private float[][][][] convertImageToFloatArray ( Bitmap image ) {
   float[][][][] imageArray = new   float[1][modelInputDim][modelInputDim][1] ;
   for ( int x = 0 ; x < modelInputDim ; x ++ ) {
       for ( int y = 0 ; y < modelInputDim ; y ++ ) {
           float R = ( float )Color.red( image.getPixel( x , y ) );
           float G = ( float )Color.green( image.getPixel( x , y ) );
           float B = ( float )Color.blue( image.getPixel( x , y ) );
           double grayscalePixel = (( 0.3 * R ) + ( 0.59 * G ) + ( 0.11 * B )) / 255;
           imageArray[0][x][y][0] = (float)grayscalePixel ;
       }
   }
   return imageArray ;
}

modelInputDim是图像的模型输入大小。上面的代码片段将RGB图像转换为灰度图像。

现在,我们执行最终推断,

private int modelInputDim = 28 ;
private int outputDim = 3 ;

private float[] performInference(Bitmap frame , RectF cropImageRectF ) {
   Bitmap croppedBitmap = getCroppedBitmap( frame , cropImageRectF ) ;
   Bitmap croppedFrame = resizeBitmap( croppedBitmap );
   float[][][][] imageArray = convertImageToFloatArray( croppedFrame ) ;
   float[][] outputArray = new float[1][outputDim] ;
   interpreter.run( imageArray , outputArray ) ;
   return outputArray[0] ;
}

我准备了一系列Android应用,这些应用在Android中使用了TFLite模型。参见here