我们正在尝试在pytorch中使用CNN进行多标签分类。我们将90/10分割为训练/验证集使用8个标签和大约260张图像。
班级高度不平衡,最频繁的班级出现在140多个图像中。另一方面,最不频繁的类别出现在少于5张图像中。
我们最初尝试使用BCEWithLogitsLoss函数,该函数导致模型针对所有图像预测相同的标签。
然后,我们实施了一种损失集中的方法来处理班级失衡,如下所示:
import torch.nn as nn
import torch
class FocalLoss(nn.Module):
def __init__(self, alpha=1, gamma=2):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
def forward(self, outputs, targets):
bce_criterion = nn.BCEWithLogitsLoss()
bce_loss = bce_criterion(outputs, targets)
pt = torch.exp(-bce_loss)
focal_loss = self.alpha * (1 - pt) ** self.gamma * bce_loss
return focal_loss
这导致该模型预测每个图像的空集(没有标签),因为对于任何类别都无法获得大于0.5的置信度。
pytorch中是否有一种方法可以解决这种情况?
答案 0 :(得分:1)
基本上有三种处理方式。
丢弃更常见类中的数据
少数族裔的体重减轻值
对少数群体过度采样
选项1通过选择数据集中包含的文件来实现。
选项2通过pos_weight
的{{1}}参数实现
选项3通过将自定义BCEWithLogitsLoss
传递到您的数据加载器来实现
对于深度学习,过采样通常最有效。