谁能解释一下eager_few_shot_od_training_tflite.ipynb 代码

时间:2021-02-26 07:05:04

标签: image tensorflow object-detection tensorflow-lite

我正在尝试理解 tensorflow 团队在 github 链接 eager_few_shot_od_training_tflite.ipynb 中提供的示例代码。除了下面几行代码,我能理解大部分代码。

  1. 不确定我们为什么要创建 fake_box_predictor 变量并创建 fake_model
  2. 为什么我们不能直接在 detection_model 上恢复检查点而不是创建 fake_model
  3. 我也不明白上面代码的注释“我们将恢复框回归头,但从头开始初始化分类头”。

谁能详细解释下面的代码在做什么(以及这段代码上面提供的注释是什么意思)?

# Set up object-based checkpoint restore --- SSD has two prediction
# `heads` --- one for classification, the other for box regression.  We will
# restore the box regression head but initialize the classification head
# from scratch (we show the omission below by commenting out the line that
# we would add if we wanted to restore both heads)
fake_box_predictor = tf.compat.v2.train.Checkpoint(
    _base_tower_layers_for_heads=detection_model._box_predictor._base_tower_layers_for_heads,
    # _prediction_heads=detection_model._box_predictor._prediction_heads,
    #    (i.e., the classification head that we *will not* restore)
    _box_prediction_head=detection_model._box_predictor._box_prediction_head,
    )
fake_model = tf.compat.v2.train.Checkpoint(
          _feature_extractor=detection_model._feature_extractor,
          _box_predictor=fake_box_predictor)
ckpt = tf.compat.v2.train.Checkpoint(model=fake_model)
ckpt.restore(checkpoint_path).expect_partial()

1 个答案:

答案 0 :(得分:1)

有关此主题的详细信息和文档非常稀少。我对此的理解是,

2. 这里为一些层加载了预训练检查点的权重以进行迁移学习,其余的则从头开始初始化以学习新类。检测模型是根据给定的配置构建的。如果 checkpointdetection_model 上恢复,那么它将检测来自 COCO 数据集的类,因为 TF2 detections model zoo 预训练模型是在 COCO 数据集上训练的。

3. 目标是对新类别的图像进行分类。这与类预测层相关联。所以这些层应该从头开始初始化,其余的特征提取层,边界框回归层从检查点恢复,这样模型就可以利用这些预训练的权重。这将有助于模型更快地收敛并检测新的所需类别。

1. 这是加载部分模型的过程。类预测和框回归头之前的 base_tower_layers_for_heads contains the earlier layers

_box_prediction_head 预测边界框。如果使用 _prediction_heads=detection_model._box_predictor._prediction_heads,它应该恢复分类和回归头,因为它contains both box_prediction_headsclass_prediction_heads

detection_model._feature_extractorlikely the initial layers containing 分类网络,例如用于特征提取的 Resnet、Mobilenet + FPN。

fake_model 将 bbox 回归头 + 其较早的层与计算图中的基本特征提取层(例如 mobilenet)连接起来。 vars(fake_model) 将包含 _box_predictor_feature_extractor

然后从预训练的检查点恢复所需的权重,expect_partial() 用于 silence warning 检查点未使用的部分。

coursera 上的 Tensorflow 高级计算机视觉课程有更多关于这方面的细节以及整体的 tensorflow 对象检测结构和代码。欢迎纠正任何错误。