如何使用pytorch处理多标签分类中的类不平衡

时间:2020-05-18 21:53:04

标签: machine-learning pytorch multilabel-classification cnn

我们正在尝试在pytorch中使用CNN进行多标签分类。我们将90/10分割为训练/验证集使用8个标签和大约260张图像。

班级高度不平衡,最频繁的班级出现在140多个图像中。另一方面,最不频繁的类别出现在少于5张图像中。

我们最初尝试使用BCEWithLogitsLoss函数,该函数导致模型针对所有图像预测相同的标签。

然后,我们实施了一种损失集中的方法来处理班级失衡,如下所示:

    import torch.nn as nn
    import torch

    class FocalLoss(nn.Module):
        def __init__(self, alpha=1, gamma=2):
            super(FocalLoss, self).__init__()
            self.alpha = alpha
            self.gamma = gamma

        def forward(self, outputs, targets):
            bce_criterion = nn.BCEWithLogitsLoss()
            bce_loss = bce_criterion(outputs, targets)
            pt = torch.exp(-bce_loss)
            focal_loss = self.alpha * (1 - pt) ** self.gamma * bce_loss
            return focal_loss 

这导致该模型预测每个图像的空集(没有标签),因为对于任何类别都无法获得大于0.5的置信度。

pytorch中是否有一种方法可以解决这种情况?

1 个答案:

答案 0 :(得分:1)

基本上有三种处理方式。

  1. 丢弃更常见类中的数据

  2. 少数族裔的体重减轻值

  3. 对少数群体过度采样

选项1通过选择数据集中包含的文件来实现。

选项2通过pos_weight的{​​{1}}参数实现

选项3通过将自定义BCEWithLogitsLoss传递到您的数据加载器来实现

对于深度学习,过采样通常最有效。

相关问题