Poor detections after building a TF2 Model Zoo model and restoring the provided checkpoint

Time: 2020-09-05 14:12:10

Tags: python tensorflow object-detection

Description

I followed the TF2 Object Detection API tutorial and tried the two models below:

http://download.tensorflow.org/models/object_detection/tf2/20200711/ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz
http://download.tensorflow.org/models/object_detection/tf2/20200711/efficientdet_d1_coco17_tpu-32.tar.gz

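For each model I tried both config files, for example ssd_mobilenet_v1_fpn_640x640_coco17_tpu-8.config from the repo and the corresponding pipeline.config from the downloaded archive (the two files differ in some places).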
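To see exactly where the two config files differ, a small diff helper can be useful. This is only a sketch using the standard-library difflib; the function name diff_config_text and the inline config fragments are made up for illustration and are not the real file contents.

```python
import difflib


def diff_config_text(text_a, text_b, name_a="a.config", name_b="b.config"):
    """Return unified-diff lines between two config texts.

    Indentation and blank lines are ignored so that only real
    differences in fields and values show up.
    """
    lines_a = [ln.strip() for ln in text_a.splitlines() if ln.strip()]
    lines_b = [ln.strip() for ln in text_b.splitlines() if ln.strip()]
    return list(difflib.unified_diff(
        lines_a, lines_b, fromfile=name_a, tofile=name_b, lineterm="", n=0))


# Made-up fragments standing in for the two real files (illustrative only).
repo_text = """
model {
  ssd {
    num_classes: 90
    hypothetical_option: 1
  }
}
"""
archive_text = """
model {
  ssd {
    num_classes: 90
    hypothetical_option: 2
  }
}
"""

for line in diff_config_text(repo_text, archive_text,
                             "configs/tf2 copy", "pipeline.config"):
    print(line)
```

With the real files, the two texts would be read from disk (for example with open(path).read()) and passed in the same way.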

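When I build the model and restore the checkpoint, the detections on images 1 and 2 in models/research/object_detection/test_images are poor, but when I use model.load with the saved model from the same gz file, the detections are good.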
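One way to put a number on "poor" versus "good" is to check how many confident boxes from the good run also appear in the checkpoint run. The sketch below assumes the detections have already been converted to plain Python lists, with boxes in the normalized [ymin, xmin, ymax, xmax] order that the Object Detection API uses for detection_boxes. The helper names iou and matched_fraction, the thresholds, and the sample values are hypothetical.

```python
def iou(box_a, box_b):
    """Intersection over union of two [ymin, xmin, ymax, xmax] boxes."""
    ymin = max(box_a[0], box_b[0])
    xmin = max(box_a[1], box_b[1])
    ymax = min(box_a[2], box_b[2])
    xmax = min(box_a[3], box_b[3])
    inter = max(0.0, ymax - ymin) * max(0.0, xmax - xmin)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def matched_fraction(boxes_ref, scores_ref, boxes_test, scores_test,
                     min_score=0.5, min_iou=0.5):
    """Fraction of confident reference boxes overlapped by a confident test box."""
    ref = [b for b, s in zip(boxes_ref, scores_ref) if s >= min_score]
    test = [b for b, s in zip(boxes_test, scores_test) if s >= min_score]
    if not ref:
        return 0.0
    hits = sum(1 for r in ref if any(iou(r, t) >= min_iou for t in test))
    return hits / len(ref)


# Illustrative values only: "ref" plays the saved-model run,
# "test" plays the restored-checkpoint run.
ref_boxes = [[0.1, 0.1, 0.5, 0.5], [0.6, 0.6, 0.9, 0.9]]
ref_scores = [0.9, 0.8]
test_boxes = [[0.1, 0.1, 0.5, 0.5], [0.0, 0.0, 0.05, 0.05]]
test_scores = [0.7, 0.9]

print(matched_fraction(ref_boxes, ref_scores, test_boxes, test_scores))
```

A low fraction on the test images would confirm that the two loading paths really disagree, rather than the difference coming from visualization settings.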

I would appreciate any guidance.

Code snippet


import os
import time

import numpy as np
import tensorflow as tf
from object_detection.builders import model_builder
from object_detection.utils import config_util

model_name = 'ssd_resnet50_v1_fpn_640x640_coco17_tpu-8'

# True: use the config from the repo's configs/tf2 directory.
# False: use the pipeline.config shipped in the downloaded archive.
# (I tried both settings.)
tf2_config = True

if tf2_config:
  pipeline_config = os.path.join('models/research/object_detection/configs/tf2/',
                                 model_name + '.config')
else:
  pipeline_config = 'ssd_resnet50_v1_fpn_640x640_coco17_tpu-8/pipeline.config'

model_dir = 'ssd_resnet50_v1_fpn_640x640_coco17_tpu-8/checkpoint/'

# Load pipeline config and build a detection model
configs = config_util.get_configs_from_pipeline_file(pipeline_config)
model_config = configs['model']

detection_model_bld = model_builder.build(model_config=model_config, is_training=False)

# Restore checkpoint
ckpt = tf.compat.v2.train.Checkpoint(model=detection_model_bld)
ckpt.restore(os.path.join(model_dir, 'ckpt-0')).expect_partial()

def get_model_detection_function(model):
  """Get a tf.function for detection."""

  @tf.function
  def detect_fn(image):
    """Detect objects in image."""

    image, shapes = model.preprocess(image)
    prediction_dict = model.predict(image, shapes)
    detections = model.postprocess(prediction_dict, shapes)

    return detections, prediction_dict, tf.reshape(shapes, [-1])

  return detect_fn

detect_fn_ckpt = get_model_detection_function(detection_model_bld)

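# image_path and load_image_into_numpy_array are defined elsewhere (not shown here)
image_np = load_image_into_numpy_array(image_path)

# Add a batch dimension and convert to float32 before running detection
input_tensor = tf.convert_to_tensor(
    np.expand_dims(image_np, 0), dtype=tf.float32)

start_time = time.time()
detections, predictions_dict, shapes = detect_fn_ckpt(input_tensor)
end_time = time.time()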
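
A side note on the timing above: the first call to a function decorated with @tf.function also traces the graph, so a single start/end measurement around that call includes tracing time. A minimal sketch of timing with warm-up calls follows; time_calls is a hypothetical helper, and the stand-in function only takes the place of detect_fn_ckpt so the example runs without TensorFlow.

```python
import time


def time_calls(fn, arg, warmup=1, repeats=5):
    """Call fn(arg) `warmup` times untimed, then return `repeats` timings in seconds."""
    for _ in range(warmup):
        fn(arg)
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(arg)
        timings.append(time.perf_counter() - start)
    return timings


# Stand-in for the detection function (assumption, for illustration only).
def fake_detect(x):
    return sum(i * x for i in range(10000))


timings = time_calls(fake_detect, 3, warmup=1, repeats=5)
print("best of %d runs: %.6f s" % (len(timings), min(timings)))
```

With the real model one would pass detect_fn_ckpt and input_tensor instead. On a GPU it may also be necessary to pull a result back to the host inside the timed function (for example by calling .numpy() on an output) so that the measurement covers the full computation.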

0 answers:

No answers yet