我按照 TF2 Object Detection API 教程进行了操作，并尝试了以下两个模型：
- http://download.tensorflow.org/models/object_detection/tf2/20200711/ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.tar.gz
- http://download.tensorflow.org/models/object_detection/tf2/20200711/efficientdet_d1_coco17_tpu-32.tar.gz
例如，对于 ssd_resnet50_v1_fpn_640x640_coco17_tpu-8，我分别尝试了 `configs/tf2/` 目录中的 `ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.config` 和压缩包中对应的 `pipeline.config`（这两个文件有一些差异）。
通过恢复检查点（checkpoint）进行推理时，在 `models/research/object_detection/test_images` 中的图像 1 和图像 2 上检测结果较差；但对同一 gz 文件中的模型使用 `model.load` 加载时，检测结果良好。
我希望您能提供一些指导。
model_name = 'ssd_resnet50_v1_fpn_640x640_coco17_tpu-8'

# Every path below is derived from `model_name`. Previously the fallback
# config path and the checkpoint directory hard-coded the SSD ResNet50 name.
# Switching `model_name` (e.g. to 'efficientdet_d1_coco17_tpu-32') therefore
# silently paired one model's config with another model's checkpoint.
if tf2_config:
    # Config from the Object Detection API source tree. It can differ from the
    # pipeline.config bundled next to the checkpoint. Any such difference may
    # leave variables un-restored (see the note on expect_partial below).
    pipeline_config = os.path.join('models/research/object_detection/configs/tf2/',
                                   model_name + '.config')
else:
    # Config bundled in the extracted archive alongside the checkpoint.
    pipeline_config = model_name + '/pipeline.config'
model_dir = model_name + '/checkpoint/'

# Load pipeline config and build a detection model in inference mode.
configs = config_util.get_configs_from_pipeline_file(pipeline_config)
model_config = configs['model']
detection_model_bld = model_builder.build(model_config=model_config, is_training=False)

# Restore checkpoint.
ckpt = tf.compat.v2.train.Checkpoint(model=detection_model_bld)
# expect_partial() silences warnings about checkpoint values that are never
# matched, so a config/checkpoint mismatch goes unnoticed. The returned
# status is kept so the restore can be checked later. Variables may only be
# created on the model's first call, so check after running it once, e.g.
#     ckpt_status.assert_existing_objects_matched()
ckpt_status = ckpt.restore(os.path.join(model_dir, 'ckpt-0')).expect_partial()
def get_model_detection_function(model):
    """Build a compiled detection callable around ``model``.

    Args:
      model: Object exposing ``preprocess``, ``predict`` and ``postprocess``
        (in this script, the model returned by ``model_builder.build``).

    Returns:
      A ``tf.function`` that takes an image tensor and returns the tuple
      ``(detections, prediction_dict, flattened_shapes)``.
    """

    def detect_fn(image):
        """Run preprocess -> predict -> postprocess on ``image``."""
        preprocessed_images, true_shapes = model.preprocess(image)
        raw_predictions = model.predict(preprocessed_images, true_shapes)
        postprocessed = model.postprocess(raw_predictions, true_shapes)
        # Flatten the shapes tensor returned by preprocess into 1-D.
        flat_shapes = tf.reshape(true_shapes, [-1])
        return postprocessed, raw_predictions, flat_shapes

    # Equivalent to decorating detect_fn with @tf.function.
    return tf.function(detect_fn)
detect_fn_ckpt = get_model_detection_function(detection_model_bld)

image_np = load_image_into_numpy_array(image_path)
# Add a leading batch dimension of size 1 and cast to float32 for the model.
input_tensor = tf.convert_to_tensor(
    np.expand_dims(image_np, 0), dtype=tf.float32)

# perf_counter() is a monotonic, high-resolution clock intended for measuring
# elapsed time. time.time() follows the system clock and can jump.
# NOTE: this is the first call of detect_fn_ckpt, so the interval includes
# tf.function tracing, not only inference. Time a second call for a
# steady-state figure.
start_time = time.perf_counter()
detections, predictions_dict, shapes = detect_fn_ckpt(input_tensor)
end_time = time.perf_counter()