如何使用最新的MobileNet(v3)进行对象检测?

时间:2019-12-19 06:10:17

标签: java tensorflow object-detection tensorflow-lite mobilenet

我一直在尝试使用最新的MobileNet MobileNet_v3来运行对象检测。您可以从以下位置找到Google对此进行过预训练的模型,例如我正在尝试使用的模型“ ssd_mobilenet_v3_large_coco”:https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/detection_model_zoo.md

我不知道这些新模型如何获取图像数据输入,也找不到有关此在线内容的任何深入文档。以下Java代码总结了我如何尝试从网上可以收集的有限数量中获取模型数据(特别是使用TensorFlow Lite的.tflite模型)的图像数据,但是该模型仅返回10 ^ -20的预测置信度,因此它从不真正识别任何东西。由此我认为我一定做错了。

//Note that the model takes a 320 x 320 image


//Get image data as integer values
private int[] intValues;
intValues = new int[320 * 320];
private Bitmap croppedBitmap = null;
croppedBitmap = Bitmap.createBitmap(320, 320, Config.ARGB_8888);
croppedBitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());

//create ByteBuffer as input for running ssd_mobilenet_v3
private ByteBuffer imgData;
imgData = ByteBuffer.allocateDirect(320 * 320 * 3);
imgData.order(ByteOrder.nativeOrder());

//fill Bytebuffer
//Note that & 0xFF is for just getting the last 8 bits, which converts to RGB values here
imgData.rewind();
for (int i = 0; i < inputSize; ++i) {
  for (int j = 0; j < inputSize; ++j) {
    int pixelValue = intValues[i * inputSize + j];
    // Quantized model
    imgData.put((byte) ((pixelValue >> 16) & 0xFF));
    imgData.put((byte) ((pixelValue >> 8) & 0xFF));
    imgData.put((byte) (pixelValue & 0xFF));
  }
}

// Set up output buffers
private float[][][] output0;
private float[][][][] output1;
output0 = new float[1][2034][91];
output1 = new float[1][2034][1][4];

//Create input HashMap and run the model
Object[] inputArray = {imgData};
Map<Integer, Object> outputMap = new HashMap<>();
outputMap.put(0, output0);
outputMap.put(1, output1);
tfLite.runForMultipleInputsOutputs(inputArray, outputMap);

//Examine Confidences to see if any significant detentions were made
for (int i = 0; i < 2034; i++) {
  for (int j = 0; j < 91; j++) {
    System.out.println(output0[0][i][j]);
  }
}

1 个答案:

答案 0 :(得分:3)

我已经弄清楚了如何通过一些额外的努力使它工作。

您必须下载经过预先​​训练的模型,然后自己重​​新创建.tflite文件,以使其与提供的android代码一起使用。由Tensorflow团队编写的以下指南向您展示了如何重新创建.tflite文件,以使它们具有与android对象检测代码所接受的相同的输入/输出格式:

https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/running_on_mobile_tensorflowlite.md

这样,您几乎无需更改任何用于对象检测的android代码。您需要手动指定的唯一对象(在创建.tflite文件时以及在用于对象检测的android代码中)都是对象检测模型的分辨率。

因此,对于分辨率为320x320的mobilenet_v3,将模型转换为.tflite文件时,请使用标志“ --input_shapes = 1,320,320,3”。然后,在android代码中设置变量“ TF_OD_API_INPUT_SIZE = 320”。这些是您唯一需要做的更改。

从理论上讲,它可用于任何(且仅)ssd模型,但我目前仅使用mobilenet_v2对其进行了测试,因为它更易于使用并且v2和v3之间的差异可以忽略。