Tensorflow对象检测-避免重叠框

时间:2019-02-05 16:08:06

标签: python object tensorflow detection

简介: 我是机器学习的新手,我和一位同事必须实现一种用于检测交通信号灯的算法。我下载了一个预先训练的模型(更快的rcnn),并运行了几个训练步骤(〜10000)。现在,当使用来自tensorflow git存储库的对象检测算法时,会在一个区域中检测到多个交通信号灯。

我做了一些研究,发现了函数“ tf.image.non_max_suppression”,但是我无法使其按预期工作(说实话,我什至无法运行它)。

我假设您知道tf对象检测示例代码,因此您还知道所有框都是使用字典(output_dict)返回的。

“清理”我使用的盒子:

selected_indices = tf.image.non_max_suppression(
        boxes           = output_dict['detection_boxes'],
        scores          = output_dict['detection_scores'],
        max_output_size = 1,
        iou_threshold   = 0.5,
        score_threshold = float('-inf'),
        name            = None)

起初我以为可以将selected_indices用作新的盒子列表,所以我尝试了以下方法:

vis_util.visualize_boxes_and_labels_on_image_array(
      image                      = image_np,
      boxes                      = selected_indices,
      classes                    = output_dict['detection_classes'],
      scores                     = output_dict['detection_scores'],
      category_index             = category_index,
      instance_masks             = output_dict.get('detection_masks'),
      use_normalized_coordinates = True)

但是当我注意到这行不通时,我发现了一个必需的方法:“ tf.gather()”。然后我运行以下代码:

boxes = output_dict['detection_boxes']
selected_indices = tf.image.non_max_suppression(
    boxes           = boxes,
    scores          = output_dict['detection_scores'],
    max_output_size = 1,
    iou_threshold   = 0.5,
    score_threshold = float('-inf'),
    name            = None)

selected_boxes = tf.gather(boxes, selected_indices)

vis_util.visualize_boxes_and_labels_on_image_array(
      image                      = image_np,
      boxes                      = selected_boxes,
      classes                    = output_dict['detection_classes'],
      scores                     = output_dict['detection_scores'],
      category_index             = category_index,
      instance_masks             = output_dict.get('detection_masks'),
      use_normalized_coordinates = True)

但没有一个可行。我在第689行的visualisation_utils.py中收到AttributeError(“ Tensor”对象没有属性“ tolist”)。

1 个答案:

答案 0 :(得分:0)

因此,看起来要以正确的格式获取框,您需要创建一个会话并评估张量,如下所示:

suppressed = tf.image.non_max_suppression(output_dict['detection_boxes'], output_dict['detection_scores'], 5) # Replace 5 with max num desired boxes

sboxes = tf.gather(output_dict['detection_boxes'], suppressed)
sscores = tf.gather(output_dict['detection_scores'], suppressed)
sclasses = tf.gather(output_dict['detection_classes'], suppressed)

sess = tf.Session()
with sess.as_default():
    boxes = sboxes.eval()
    scores =sscores.eval()
    classes = sclasses.eval()

vis_util.visualize_boxes_and_labels_on_image_array(
      image_np,
      boxes,
      classes,
      scores,
      category_index,
      instance_masks=output_dict.get('detection_masks'),
      use_normalized_coordinates=True,
      line_thickness=8)