我正在尝试理解 tensorflow 团队在 github 链接 eager_few_shot_od_training_tflite.ipynb 中提供的示例代码。除了下面几行代码,我能理解大部分代码。
fake_box_predictor
变量并创建 fake_model
?detection_model
上恢复检查点而不是创建 fake_model
?谁能详细解释下面的代码在做什么(以及这段代码上面提供的注释是什么意思)?
# Set up object-based checkpoint restore --- SSD has two prediction
# `heads` --- one for classification, the other for box regression. We will
# restore the box regression head but initialize the classification head
# from scratch (we show the omission below by commenting out the line that
# we would add if we wanted to restore both heads)
fake_box_predictor = tf.compat.v2.train.Checkpoint(
_base_tower_layers_for_heads=detection_model._box_predictor._base_tower_layers_for_heads,
# _prediction_heads=detection_model._box_predictor._prediction_heads,
# (i.e., the classification head that we *will not* restore)
_box_prediction_head=detection_model._box_predictor._box_prediction_head,
)
fake_model = tf.compat.v2.train.Checkpoint(
_feature_extractor=detection_model._feature_extractor,
_box_predictor=fake_box_predictor)
ckpt = tf.compat.v2.train.Checkpoint(model=fake_model)
ckpt.restore(checkpoint_path).expect_partial()
答案 0 :(得分:1)
有关此主题的详细信息和文档非常稀少。我对此的理解是,
2. 这里为一些层加载了预训练检查点的权重以进行迁移学习,其余的则从头开始初始化以学习新类。检测模型是根据给定的配置构建的。如果 checkpoint 在 detection_model
上恢复,那么它将检测来自 COCO
数据集的类,因为 TF2 detections model zoo 预训练模型是在 COCO 数据集上训练的。
3. 目标是对新类别的图像进行分类。这与类预测层相关联。所以这些层应该从头开始初始化,其余的特征提取层,边界框回归层从检查点恢复,这样模型就可以利用这些预训练的权重。这将有助于模型更快地收敛并检测新的所需类别。
1. 这是加载部分模型的过程。类预测和框回归头之前的 base_tower_layers_for_heads
contains the earlier layers。
_box_prediction_head
预测边界框。如果使用 _prediction_heads=detection_model._box_predictor._prediction_heads
,它应该恢复分类和回归头,因为它contains both box_prediction_heads
和 class_prediction_heads
。
detection_model._feature_extractor
是 likely the initial layers containing 分类网络,例如用于特征提取的 Resnet、Mobilenet + FPN。
fake_model
将 bbox 回归头 + 其较早的层与计算图中的基本特征提取层(例如 mobilenet)连接起来。 vars(fake_model)
将包含 _box_predictor
和 _feature_extractor
。
然后从预训练的检查点恢复所需的权重,expect_partial()
用于 silence warning 检查点未使用的部分。
coursera 上的 Tensorflow 高级计算机视觉课程有更多关于这方面的细节以及整体的 tensorflow 对象检测结构和代码。欢迎纠正任何错误。