如何使用Mask R-CNN模型将一批图像传递给Tensorflows对象检测API

时间:2018-07-04 08:14:04

标签: python tensorflow image-processing object-detection object-detection-api

我正在使用tensorflows对象检测api,并且出于我的目的操纵了object_destection_tutorial笔记本。为了加快计算速度,我想处理一批图像,那么当加载的模型为mask_rcnn模型时,如何将一批图像传递给sess.run()命令?批处理形状为(batch_size,image_width,image_height,3),并且使用常规的fast_rcnn体系结构,它的工作原理如下:

output_dict = sess.run(tensor_dict, feed_dict={image_tensor: image_batch})

但是,如果我将模型更改为mask_rcnn架构,它将不允许任何批处理,该批处理大于(1,image_width,image_height,3),否则将返回例如以下错误消息:

InvalidArgumentError (see above for traceback): Tried to explicitly squeeze dimension 0 but dimension was not 1: 2
     [[Node: Squeeze_5 = Squeeze[T=DT_FLOAT, squeeze_dims=[0], _device="/job:localhost/replica:0/task:0/device:CPU:0"](detection_boxes)]]

我知道它正在尝试将image_batch压缩回尺寸(image_width,image_height,3),该尺寸仅在batch_size为1时有效。但是如何避免这种情况呢?我认为关键是要操纵这些代码行(尤其是最后两行),因为它们似乎为sess.run()命令创建了一个空的output_dict(对吗?):

if 'detection_masks' in tensor_dict:
        # The following processing is only for single image
        detection_boxes = tf.squeeze(tensor_dict['detection_boxes'], [0])
        detection_masks = tf.squeeze(tensor_dict['detection_masks'], [0])
        # Reframe is required to translate mask from box coordinates to image coordinates and fit the image size.
        real_num_detection = tf.cast(tensor_dict['num_detections'][0], tf.int32)
        detection_boxes = tf.slice(detection_boxes, [0, 0], [real_num_detection, -1])
        detection_masks = tf.slice(detection_masks, [0, 0, 0], [real_num_detection, -1, -1])
        detection_masks_reframed = utils_ops.reframe_box_masks_to_image_masks(
            detection_masks, detection_boxes, image_batch.shape[1], image_batch.shape[2])
        detection_masks_reframed = tf.cast(
            tf.greater(detection_masks_reframed, 0.5), tf.uint8)
        # Follow the convention by adding back the batch dimension
        tensor_dict['detection_masks'] = tf.expand_dims(
            detection_masks_reframed, 0)

0 个答案:

没有答案