How do I accurately retrieve the bounding boxes of objects detected with the TensorFlow Object Detection API?

Time: 2019-06-17 20:11:47

Tags: python opencv tensorflow

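I am trying to understand how to find the position of the bounding box when an object is detected. I am using the TensorFlow Object Detection API to detect a mouse in a box. Purely to test how to retrieve the bounding box coordinates, I want to print "this is a mouse" right above the mouse's head whenever it is detected. However, my text currently prints several inches off target. For example, here is a screenshot from my object detection video.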

screenshot

Here is the relevant code snippet:

with detection_graph.as_default(), tf.Session(graph=detection_graph) as sess:
    start = time.time()
    while True:

        # Read frame from camera
        ret, image_np = cap.read()

        cv2.putText(image_np, "Time Elapsed: {}s".format(int(time.time() - start)), (50,50),cv2.FONT_HERSHEY_PLAIN,3, (0,0,255),3)
        # Expand dimensions since the model expects images to have shape: [1, None, None, 3]
        image_np_expanded = np.expand_dims(image_np, axis=0)
        # Extract image tensor
        image_tensor = detection_graph.get_tensor_by_name('image_tensor:0')
        # Extract detection boxes
        boxes = detection_graph.get_tensor_by_name('detection_boxes:0')
        # Extract detection scores
        scores = detection_graph.get_tensor_by_name('detection_scores:0')
        # Extract detection classes
        classes = detection_graph.get_tensor_by_name('detection_classes:0')
        # Extract number of detections
        num_detections = detection_graph.get_tensor_by_name(
            'num_detections:0')
        # Actual detection.
        (boxes, scores, classes, num_detections) = sess.run(
            [boxes, scores, classes, num_detections],
            feed_dict={image_tensor: image_np_expanded})
        # Visualization of the results of a detection.
        vis_util.visualize_boxes_and_labels_on_image_array(
            image_np,
            np.squeeze(boxes),
            np.squeeze(classes).astype(np.int32),
            np.squeeze(scores),
            category_index,
            use_normalized_coordinates=True,
            line_thickness=8)

        for i, b in enumerate(boxes[0]):
            if classes[0][i] == 1:
                if scores[0][i] >= .5:
                    mid_x = (boxes[0][i][3] + boxes[0][i][1]) / 2
                    mid_y = (boxes[0][i][2] + boxes[0][i][0]) / 2


                    cv2.putText(image_np, 'FOUND A MOUSE', (int(mid_x*600), int(mid_y*800)), cv2.FONT_HERSHEY_PLAIN, 2, (0,255,0), 3)

        # Display output
        cv2.imshow(vid_name, cv2.resize(image_np, (800, 600)))

        #Write to output
        video_writer.write(image_np)

        if cv2.waitKey(25) & 0xFF == ord('q'):
            cv2.destroyAllWindows()
            break


    cap.release()
    cv2.destroyAllWindows()

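I am also still unclear on how boxes works. Could someone explain this line to me: mid_x = (boxes[0][i][3] + boxes[0][i][1]) / 2? I understand that indices 1 and 3 hold x_min and x_max respectively, but I am not sure why I only iterate over boxes[0] or what i represents.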

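Solution: As ievbu suggested, I needed to convert the midpoint calculation from its normalized values to frame values. I found a cv2 function that returns the frame width and height, and used those values to convert the midpoint to a pixel position.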

frame_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
frame_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))

...
cv2.putText(image_np, '.', (int(mid_x*frame_w), int(mid_y*frame_h)), cv2.FONT_HERSHEY_PLAIN, 2, (0,255,0), 3)
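
The conversion above can be wrapped in a small helper. This is only a minimal sketch, not the poster's actual code: midpoint_in_pixels is a hypothetical name, and it assumes each box is in the normalized [ymin, xmin, ymax, xmax] order that the Object Detection API returns.

```python
def midpoint_in_pixels(box, frame_w, frame_h):
    """Return the (x, y) pixel midpoint of a normalized box.

    Assumes box = [ymin, xmin, ymax, xmax], each value in [0, 1].
    """
    ymin, xmin, ymax, xmax = box
    mid_x = (xmin + xmax) / 2
    mid_y = (ymin + ymax) / 2
    # x scales with the frame width, y with the frame height
    return int(mid_x * frame_w), int(mid_y * frame_h)


# Hypothetical 640x480 frame with a box covering its right half
print(midpoint_in_pixels([0.0, 0.5, 1.0, 1.0], frame_w=640, frame_h=480))
# prints (480, 240)
```

The returned tuple is in (x, y) order, which is the order cv2.putText expects for its text origin.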

1 answer:

Answer 0 (score: 0)

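The boxes are returned with an extra leading dimension because you can feed in multiple images, and that dimension then indexes each individual image (for a single input image, you add that dimension with np.expand_dims). You can see that np.squeeze removes it for the visualization; if you only process one image, you can remove it manually by taking boxes[0]. i is the index of a box within the boxes array, and you need that index to access the class and score of the box being analyzed.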
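
To make the shapes concrete, the sketch below mimics the API's outputs with plain nested lists so that it runs without TensorFlow. The numbers are made up for illustration, and confident_boxes is a hypothetical helper; the real outputs are NumPy arrays of shape (1, N, 4) for boxes and (1, N) for classes and scores when a single image is fed in.

```python
# Made-up outputs for a batch containing ONE image with three candidate boxes.
# The outer list is the batch dimension, so index 0 selects the only image.
boxes = [[[0.10, 0.20, 0.50, 0.60],   # [ymin, xmin, ymax, xmax]
          [0.30, 0.30, 0.90, 0.80],
          [0.00, 0.00, 0.10, 0.10]]]
classes = [[1.0, 1.0, 2.0]]
scores = [[0.92, 0.40, 0.88]]


def confident_boxes(boxes, classes, scores, wanted_class=1, min_score=0.5):
    """Collect (index, box) pairs for one class above a score threshold."""
    kept = []
    # i is the position of a box within the first image's detections; the
    # same i selects the matching entries in classes[0] and scores[0].
    for i, box in enumerate(boxes[0]):
        if classes[0][i] == wanted_class and scores[0][i] >= min_score:
            kept.append((i, box))
    return kept


print(confident_boxes(boxes, classes, scores))
```

Only the first box survives here: the second has the right class but a low score, and the third has a high score but a different class.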

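The text position is wrong because the returned box coordinates are normalized, so you have to convert them to match the full image dimensions. Here is an example of how to convert them: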

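# A NumPy image has shape (height, width, channels)
(im_height, im_width, _) = frame.shape
# Each box is [ymin, xmin, ymax, xmax], normalized to [0, 1]
ymin, xmin, ymax, xmax = box
(xmin, xmax, ymin, ymax) = (xmin * im_width, xmax * im_width,
                            ymin * im_height, ymax * im_height)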
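
Put together as a function, the conversion might look like the following. This is a sketch under the assumption that the box uses the [ymin, xmin, ymax, xmax] order and that the image dimensions come from frame.shape[:2]. to_pixel_box is a hypothetical helper name, not part of the Object Detection API, and the clamping is an added precaution rather than something the conversion requires.

```python
def to_pixel_box(box, im_height, im_width):
    """Convert a normalized [ymin, xmin, ymax, xmax] box to integer pixels.

    Returns (left, top, right, bottom), clamped to the image bounds.
    """
    ymin, xmin, ymax, xmax = box

    def clamp(value, upper):
        # Round to the nearest pixel and keep the result inside [0, upper]
        return max(0, min(int(round(value)), upper))

    left = clamp(xmin * im_width, im_width - 1)
    right = clamp(xmax * im_width, im_width - 1)
    top = clamp(ymin * im_height, im_height - 1)
    bottom = clamp(ymax * im_height, im_height - 1)
    return left, top, right, bottom


# Hypothetical 480-row by 640-column frame
print(to_pixel_box([0.25, 0.1, 0.75, 0.9], im_height=480, im_width=640))
# prints (64, 120, 576, 360)
```

The result is in the (left, top, right, bottom) order, so (left, top) and (right, bottom) can be used as the two corner points when drawing a rectangle.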