如何在Detectron2中设置自定义类权重

时间:2020-10-20 17:41:34

标签: python machine-learning computer-vision pytorch

我将https://detectron2.readthedocs.io/tutorials/install.html用于其他类和对象的数据集。

我的数据集不平衡。我希望为每个班级设置不同的权重。我该怎么办?

1 个答案:

答案 0 :(得分:0)

不幸的是,如果没有 writing your own components,还没有办法进行配置。

一种快速执行此操作的方法是编写一个新的头部,该头部继承自包含损失的头部。然后,损失将替换为使用您的损失权重初始化的新损失对象。

例如在 DeepLabV3+ 的情况下,这看起来像这样:

import torch
from torch import nn
from detectron2.modeling import SEM_SEG_HEADS_REGISTRY
from detectron2.projects.deeplab import DeepLabCE, DeepLabV3PlusHead

@SEM_SEG_HEADS_REGISTRY.register()
class MyNewHead(DeepLabV3PlusHead):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        weight = torch.Tensor([0.4, 0.6])  # Adapt to your case
        if self.loss_type == "cross_entropy":
            self.loss = nn.CrossEntropyLoss(
                reduction="mean", ignore_index=self.ignore_value, weight=weight
            )
        elif self.loss_type == "hard_pixel_mining":
            self.loss = DeepLabCE(
                ignore_label=self.ignore_value,
                top_k_percent_pixels=0.2,
                weight=weight,
            )
        else:
            raise ValueError("Unexpected loss type: %s" % self.loss_type)

然后,您修改配置文件以选择新头:

MODEL:  
  SEM_SEG_HEAD:
    NAME: "MyNewHead"