I am using TensorFlow's Object Detection API and have adapted the object_detection_tutorial notebook for my purposes. To speed up computation I would like to process a batch of images. How can I pass a batch of images to the sess.run() call when the loaded model is a mask_rcnn model? The batch has shape (batch_size, image_width, image_height, 3), and with the regular faster_rcnn architecture it works like this:
output_dict = sess.run(tensor_dict, feed_dict={image_tensor: image_batch})
However, if I switch the model to a mask_rcnn architecture, it does not accept any batch larger than (1, image_width, image_height, 3) and instead returns an error message like the following:
InvalidArgumentError (see above for traceback): Tried to explicitly squeeze dimension 0 but dimension was not 1: 2
[[Node: Squeeze_5 = Squeeze[T=DT_FLOAT, squeeze_dims=[0], _device="/job:localhost/replica:0/task:0/device:CPU:0"](detection_boxes)]]
I understand that it is trying to squeeze the batch back to shape (image_width, image_height, 3), which only works when batch_size is 1. But how do I avoid this? I think the key is to adapt the following lines (especially the last two), since they appear to build the tensor_dict that the sess.run() call later evaluates into the output_dict (is that right?):
if 'detection_masks' in tensor_dict:
    # The following processing is only for single image
    detection_boxes = tf.squeeze(tensor_dict['detection_boxes'], [0])
    detection_masks = tf.squeeze(tensor_dict['detection_masks'], [0])
    # Reframe is required to translate mask from box coordinates to
    # image coordinates and fit the image size.
    real_num_detection = tf.cast(tensor_dict['num_detections'][0], tf.int32)
    detection_boxes = tf.slice(detection_boxes, [0, 0],
                               [real_num_detection, -1])
    detection_masks = tf.slice(detection_masks, [0, 0, 0],
                               [real_num_detection, -1, -1])
    detection_masks_reframed = utils_ops.reframe_box_masks_to_image_masks(
        detection_masks, detection_boxes,
        image_batch.shape[1], image_batch.shape[2])
    detection_masks_reframed = tf.cast(
        tf.greater(detection_masks_reframed, 0.5), tf.uint8)
    # Follow the convention by adding back the batch dimension
    tensor_dict['detection_masks'] = tf.expand_dims(
        detection_masks_reframed, 0)