What does PoseNet return?

Time: 2019-05-11 02:25:45

Tags: java android tensorflow image-processing tensorflow-lite

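I am working on a project that reads an image as input and displays an output image. The output image contains lines that mark the human skeleton. I am using the pose estimation model from tensorflow-lite: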

https://www.tensorflow.org/lite/models/pose_estimation/overview

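I have read the documentation, which says the output contains a 4-dimensional array. I tried visualizing my model file with Netron, and it looks like this: [image: model visualisation]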

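I managed to get the resulting heatmap from my input, but I ran into a problem: every float in it is negative. This confuses me, and I am not sure whether I did something wrong or how I should interpret these outputs.

Here is the code that produces the output: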

            tfLite = new Interpreter(loadModelFile());
            Bitmap inputPhoto = BitmapFactory.decodeResource(getResources(), R.drawable.human2);
            inputPhoto = Bitmap.createScaledBitmap(inputPhoto, INPUT_SIZE_X, INPUT_SIZE_Y, false);
            inputPhoto = inputPhoto.copy(Bitmap.Config.ARGB_8888, true);

            int pixels[] = new int[INPUT_SIZE_X * INPUT_SIZE_Y];

            inputPhoto.getPixels(pixels, 0, INPUT_SIZE_X, 0, 0, INPUT_SIZE_X, INPUT_SIZE_Y);

            int pixelsIndex = 0;

            for (int i = 0; i < INPUT_SIZE_X; i ++) {
                for (int j = 0; j < INPUT_SIZE_Y; j++) {
                    int p = pixels[pixelsIndex];
                    inputData[0][i][j][0] = (p >> 16) & 0xff;
                    inputData[0][i][j][1] = (p >> 8) & 0xff;
                    inputData[0][i][j][2] = (p) & 0xff;
                    pixelsIndex ++;
                }
            }

            float outputData[][][][] = new float[1][23][17][17];

            tfLite.run(inputData, outputData);

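The output is an array of shape [1][23][17][17], and all of its values are negative. Does anyone who knows about this have any idea how to help me? :(

Thanks a lot!

1 Answer:

Answer 0 (score: 1)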

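This post only came up for me today, so sorry for the late answer.
You should check the Posenet.kt file. The code there is documented in a lot of detail. For example, you can see the following: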

/** Initializes an outputMap of 1 * x * y * z FloatArrays for the model processing to populate. */

private fun initOutputMap(interpreter: Interpreter): HashMap<Int, Any> {
    val outputMap = HashMap<Int, Any>()

    // 1 * 9 * 9 * 17 contains heatmaps
    val heatmapsShape = interpreter.getOutputTensor(0).shape()
    outputMap[0] = Array(heatmapsShape[0]) {
      Array(heatmapsShape[1]) {
        Array(heatmapsShape[2]) { FloatArray(heatmapsShape[3]) }
      }
    }

    // 1 * 9 * 9 * 34 contains offsets
    val offsetsShape = interpreter.getOutputTensor(1).shape()
    outputMap[1] = Array(offsetsShape[0]) {
      Array(offsetsShape[1]) { Array(offsetsShape[2]) { FloatArray(offsetsShape[3]) } }
    }

    // 1 * 9 * 9 * 32 contains forward displacements
    val displacementsFwdShape = interpreter.getOutputTensor(2).shape()
    outputMap[2] = Array(displacementsFwdShape[0]) {
      Array(displacementsFwdShape[1]) {
        Array(displacementsFwdShape[2]) { FloatArray(displacementsFwdShape[3]) }
      }
    }

    // 1 * 9 * 9 * 32 contains backward displacements
    val displacementsBwdShape = interpreter.getOutputTensor(3).shape()
    outputMap[3] = Array(displacementsBwdShape[0]) {
      Array(displacementsBwdShape[1]) {
        Array(displacementsBwdShape[2]) { FloatArray(displacementsBwdShape[3]) }
      }
    }

    return outputMap
}
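Since the question uses Java, here is a rough Java sketch of the same idea: one nested float array per output tensor, sized from that tensor's shape. The class name `OutputBuffers` and the method `allocate` are made up for illustration, and the shapes are simply the ones from the comments above, so treat it as an assumption-laden sketch rather than the sample's actual code.

```java
import java.lang.reflect.Array;
import java.util.HashMap;
import java.util.Map;

// Hypothetical helper, not part of TensorFlow Lite or the Posenet sample.
final class OutputBuffers {
    /** Allocates one zero-filled nested float array per output tensor shape. */
    static Map<Integer, Object> allocate(int[][] shapes) {
        Map<Integer, Object> outputs = new HashMap<>();
        for (int i = 0; i < shapes.length; i++) {
            // Array.newInstance builds e.g. float[1][9][9][17] from {1, 9, 9, 17}.
            outputs.put(i, Array.newInstance(float.class, shapes[i]));
        }
        return outputs;
    }

    public static void main(String[] args) {
        // Shapes copied from the comments in initOutputMap (assumed, model dependent).
        int[][] shapes = {{1, 9, 9, 17}, {1, 9, 9, 34}, {1, 9, 9, 32}, {1, 9, 9, 32}};
        Map<Integer, Object> outputs = allocate(shapes);
        float[][][][] heatmaps = (float[][][][]) outputs.get(0);
        // prints "9 x 9 x 17": height x width x keypoints
        System.out.println(heatmaps[0].length + " x " + heatmaps[0][0].length
                + " x " + heatmaps[0][0][0].length);
    }
}
```

In an Android app the shapes would come from `tfLite.getOutputTensor(i).shape()` and the map would be passed to `tfLite.runForMultipleInputsOutputs(new Object[] {inputData}, outputs)`; I left those calls out so the sketch runs without the TensorFlow Lite dependency.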

And of course how to convert the output into points on the screen:

/**
 * Estimates the pose for a single person.
 * args:
 *      bitmap: image bitmap of frame that should be processed
 * returns:
 *      person: a Person object containing data about keypoint locations and confidence scores
 */
fun estimateSinglePose(bitmap: Bitmap): Person {
    val estimationStartTimeNanos = SystemClock.elapsedRealtimeNanos()
    val inputArray = arrayOf(initInputArray(bitmap))
    Log.i(
        "posenet",
        String.format(
            "Scaling to [-1,1] took %.2f ms",
            1.0f * (SystemClock.elapsedRealtimeNanos() - estimationStartTimeNanos) / 1_000_000
        )
    )

    val outputMap = initOutputMap(getInterpreter())

    val inferenceStartTimeNanos = SystemClock.elapsedRealtimeNanos()
    getInterpreter().runForMultipleInputsOutputs(inputArray, outputMap)
    lastInferenceTimeNanos = SystemClock.elapsedRealtimeNanos() - inferenceStartTimeNanos
    Log.i(
        "posenet",
        String.format("Interpreter took %.2f ms", 1.0f * lastInferenceTimeNanos / 1_000_000)
    )

    val heatmaps = outputMap[0] as Array<Array<Array<FloatArray>>>
    val offsets = outputMap[1] as Array<Array<Array<FloatArray>>>

    val height = heatmaps[0].size
    val width = heatmaps[0][0].size
    val numKeypoints = heatmaps[0][0][0].size

    // Finds the (row, col) locations of where the keypoints are most likely to be.
    val keypointPositions = Array(numKeypoints) { Pair(0, 0) }
    for (keypoint in 0 until numKeypoints) {
        var maxVal = heatmaps[0][0][0][keypoint]
        var maxRow = 0
        var maxCol = 0
        for (row in 0 until height) {
            for (col in 0 until width) {
                if (heatmaps[0][row][col][keypoint] > maxVal) {
                    maxVal = heatmaps[0][row][col][keypoint]
                    maxRow = row
                    maxCol = col
                }
            }
        }
        keypointPositions[keypoint] = Pair(maxRow, maxCol)
    }

    // Calculating the x and y coordinates of the keypoints with offset adjustment.
    val xCoords = IntArray(numKeypoints)
    val yCoords = IntArray(numKeypoints)
    val confidenceScores = FloatArray(numKeypoints)
    keypointPositions.forEachIndexed { idx, position ->
        val positionY = keypointPositions[idx].first
        val positionX = keypointPositions[idx].second
        yCoords[idx] = (
                position.first / (height - 1).toFloat() * bitmap.height +
                        offsets[0][positionY][positionX][idx]
                ).toInt()
        xCoords[idx] = (
                position.second / (width - 1).toFloat() * bitmap.width +
                        offsets[0][positionY]
                                [positionX][idx + numKeypoints]
                ).toInt()
        confidenceScores[idx] = sigmoid(heatmaps[0][positionY][positionX][idx])
    }

    val person = Person()
    val keypointList = Array(numKeypoints) { KeyPoint() }
    var totalScore = 0.0f
    enumValues<BodyPart>().forEachIndexed { idx, it ->
        keypointList[idx].bodyPart = it
        keypointList[idx].position.x = xCoords[idx]
        keypointList[idx].position.y = yCoords[idx]
        keypointList[idx].score = confidenceScores[idx]
        totalScore += confidenceScores[idx]
    }

    person.keyPoints = keypointList.toList()
    person.score = totalScore / numKeypoints

    return person
}
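On the negative values from the question: in the code above the heatmap entries are raw scores, and they only become confidences after `sigmoid` is applied, so negative numbers are not an error in themselves. A negative raw score just maps to a confidence below 0.5. Here is a minimal Java sketch of that step with a made-up 2 x 2 heatmap; the class `HeatmapDecode` and its methods are my own names, not part of the sample.

```java
// Hypothetical illustration of the argmax + sigmoid step, using made-up data.
final class HeatmapDecode {
    /** Logistic function: maps any real-valued score into (0, 1). */
    static float sigmoid(float x) {
        return (float) (1.0 / (1.0 + Math.exp(-x)));
    }

    /** Returns {row, col} of the largest raw score for one keypoint channel. */
    static int[] argmax(float[][][] heatmap, int keypoint) {
        int bestRow = 0;
        int bestCol = 0;
        for (int row = 0; row < heatmap.length; row++) {
            for (int col = 0; col < heatmap[row].length; col++) {
                if (heatmap[row][col][keypoint] > heatmap[bestRow][bestCol][keypoint]) {
                    bestRow = row;
                    bestCol = col;
                }
            }
        }
        return new int[] {bestRow, bestCol};
    }

    public static void main(String[] args) {
        // Made-up heatmap [row][col][keypoint] with one channel; every raw score is negative.
        float[][][] heatmap = {
            {{-7.5f}, {-3.0f}},
            {{-0.5f}, {-9.0f}}
        };
        int[] best = argmax(heatmap, 0);
        float score = sigmoid(heatmap[best[0]][best[1]][0]);
        // The least negative cell wins; its confidence is below 0.5 but still valid.
        System.out.println("row=" + best[0] + " col=" + best[1] + " score=" + score);
    }
}
```

If every keypoint ends up with a very low confidence, it is worth checking the input preprocessing before suspecting the decoding.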

The whole .kt file is the core of getting from a bitmap to points on the screen!

If you need anything else, tag me.

Happy coding
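One step the excerpts do not show is the input scaling that the "Scaling to [-1,1]" log message refers to. Below is a hedged Java sketch of what such a step could look like. The constants `MEAN` and `STD` and the class `InputScaling` are assumptions of mine, and whether your particular model file expects this scaling is something to verify. The sketch also indexes the tensor as [row][col], because `Bitmap.getPixels` fills the array row by row.

```java
// Hypothetical preprocessing sketch; constants are assumed, not read from the model.
final class InputScaling {
    static final float MEAN = 128.0f; // assumed: maps 0..255 to roughly [-1, 1]
    static final float STD = 128.0f;  // assumed

    /** Converts row-major ARGB pixels into a [1][height][width][3] float tensor. */
    static float[][][][] toInput(int[] pixels, int width, int height) {
        float[][][][] input = new float[1][height][width][3];
        for (int row = 0; row < height; row++) {
            for (int col = 0; col < width; col++) {
                int p = pixels[row * width + col];
                input[0][row][col][0] = (((p >> 16) & 0xff) - MEAN) / STD; // red
                input[0][row][col][1] = (((p >> 8) & 0xff) - MEAN) / STD;  // green
                input[0][row][col][2] = ((p & 0xff) - MEAN) / STD;         // blue
            }
        }
        return input;
    }

    public static void main(String[] args) {
        // A 2 x 1 test image: one opaque black pixel, one opaque white pixel.
        int[] pixels = {0xFF000000, 0xFFFFFFFF};
        float[][][][] input = toInput(pixels, 2, 1);
        System.out.println("black red channel: " + input[0][0][0][0]);
        System.out.println("white red channel: " + input[0][0][1][0]);
    }
}
```

With these assumed constants a black pixel becomes -1.0 and a white pixel just under 1.0, whereas the loop in the question feeds the raw 0 to 255 values.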