通过Tensorflow进行图像分类可得出完全相同的预测

时间:2018-10-01 12:04:09

标签: android opencv tensorflow arcore tensorflow-lite

在我努力阐明的时候,请和我在一起。

我有一个Android应用程序,该应用程序使用OpenCV将YUV420图像转换为位图并将其传输到解释器。问题是,每次运行它时,我都会得到完全相同的类别预测,并且具有与我所指的内容完全相同的置信度值。

...
Recognitions : [macbook pro: 0.95353276, cello gripper: 0.023749515]. 
Recognitions : [macbook pro: 0.95353276, cello gripper: 0.023749515]. 
Recognitions : [macbook pro: 0.95353276, cello gripper: 0.023749515]. 
Recognitions : [macbook pro: 0.95353276, cello gripper: 0.023749515].
...

现在,在您提到我的模型训练不充分之前,我已经在Tensorflow Codelab-2中提供的TFLite示例中测试了完全相同的.tflite文件。它可以正常工作,并能以90%以上的精度识别我所有的4个班级。此外,我使用了label_image.py脚本来测试.pb所源自的.tflite文件,它可以正常工作。我已经在每个课程的近5000幅图像上训练了该模型。由于它可以在其他应用程序上使用,因此我猜模型并没有问题,但我的实现没有问题。尽管我只是无法查明。

以下代码用于根据图像字节创建Mat:

//Retrieve the camera Image from ARCore
val cameraImage = frame.acquireCameraImage()
val cameraPlaneY = cameraImage.planes[0].buffer
val cameraPlaneUV = cameraImage.planes[1].buffer

// Create a new Mat with OpenCV. One for each plane - Y and UV
val y_mat = Mat(cameraImage.height, cameraImage.width, CvType.CV_8UC1, cameraPlaneY)
val uv_mat = Mat(cameraImage.height / 2, cameraImage.width / 2, CvType.CV_8UC2, cameraPlaneUV)
var mat224 = Mat()
var cvFrameRGBA = Mat()

// Retrieve an RGBA frame from the produced YUV
Imgproc.cvtColorTwoPlane(y_mat, uv_mat, cvFrameRGBA, Imgproc.COLOR_YUV2BGRA_NV21)
// I've tried the following in the above line
// Imgproc.COLOR_YUV2RGBA_NV12
// Imgproc.COLOR_YUV2RGBA_NV21
// Imgproc.COLOR_YUV2BGRA_NV12
// Imgproc.COLOR_YUV2BGRA_NV21

以下代码用于将图像数据添加到ByteBuffer中:

// imageFrame is a Mat object created from OpenCV by processing a YUV420 image received from ARCore
override fun setImageFrame(imageFrame: Mat) {
    ...
    // Convert mat224 into a float array that can be sent to Tensorflow
    val rgbBytes: ByteBuffer = ByteBuffer.allocate(1 * 4 * 224 * 224 * 3)
    rgbBytes.order(ByteOrder.nativeOrder())

    val frameBitmap = Bitmap.createBitmap(imageFrame.cols(), imageFrame.rows(), Bitmap.Config.ARGB_8888, true)
    // convert Mat to Bitmap
    Utils.matToBitmap(imageFrame, frameBitmap, true)
    frameBitmap.getPixels(intValues, 0, frameBitmap.width, 0, 0, frameBitmap.width, frameBitmap.height)

    // Iterate over all pixels and retrieve information of RGB channels
    intValues.forEach { packedPixel ->
        rgbBytes.putFloat((((packedPixel shr 16) and 0xFF) - 128) / 128.0f)
        rgbBytes.putFloat((((packedPixel shr 8) and 0xFF) - 128) / 128.0f)
        rgbBytes.putFloat(((packedPixel and 0xFF) - 128) / 128.0f)
    }
}

.......
private var labelProb: Array<FloatArray>? = null
.......
// and classify 
labelProb?.let { interpreter?.run(rgbBytes, it) }
.......

我检查了从Mat转换的位图。它会尽可能地显示出来。

有任何想法吗?

更新一个

我略微更改了setImageFrame方法的实现,以匹配实现here。既然它对他有用,我希望它也对我有用。仍然没有。

override fun setImageFrame(imageFrame: Mat) {
    // Reset the rgb bytes buffer
    rgbBytes.rewind()

    // Iterate over all pixels and retrieve information of RGB channels only
    for(rows in 0 until imageFrame.rows())
        for(cols in 0 until imageFrame.cols()) {
            val imageData = imageFrame.get(rows, cols)
            // Type of Mat is 24
            // Channels is 4
            // Depth is 0
            rgbBytes.putFloat(imageData[0].toFloat())
            rgbBytes.putFloat(imageData[1].toFloat())
            rgbBytes.putFloat(imageData[2].toFloat())
        }
}

更新两个

对我的浮动模型有所怀疑,我将其更改为预先构建的MobileNet Quant模型,只是为了消除这种可能性。问题仍然存在。

...
Recognitions : [candle: 18.0, otterhound: 15.0, syringe: 13.0, English foxhound: 11.0]
Recognitions : [candle: 18.0, otterhound: 15.0, syringe: 13.0, English foxhound: 11.0]
Recognitions : [candle: 18.0, otterhound: 15.0, syringe: 13.0, English foxhound: 11.0]
Recognitions : [candle: 18.0, otterhound: 15.0, syringe: 13.0, English foxhound: 11.0]
...

1 个答案:

答案 0 :(得分:0)

好的。所以经过4天,我终于能够解决这个问题。问题是ByteBuffer是如何启动的。我在做:

private var rgbBytes: ByteBuffer = ByteBuffer.allocate(1 * 4 * 224 * 224 * 3)

代替我应该做的事情:

private val rgbBytes: ByteBuffer = ByteBuffer.allocateDirect(1 * 4 * 224 * 224 * 3)

我试图了解ByteBuffer.allocate()ByteBuffer.allocateDirect() here之间的区别,但无济于事。

如果有人能够回答另外两个问题,我会很高兴:

  1. 为什么Tensorflow需要直接字节缓冲区而不是非直接缓冲区?
  2. 在简化的描述中,直接字节缓冲区和非直接字节缓冲区有什么区别?