Pytorch中类别不平衡的多标签分类

时间:2019-10-02 17:17:37

标签: pytorch multilabel-classification imbalanced-data

我有一个多标签分类问题,我正在尝试使用Pytorch中的CNN来解决。我有80,000个培训示例和7900个课程;每个示例可以同时属于多个类,每个示例的平均类数为130。

问题是我的数据集非常不平衡。对于某些类,我只有〜900个示例,大约是1%。对于“代表过多”的类,我有大约12000个示例(15%)。训练模型时,我使用来自pytorch的BCEWithLogitsLoss和正的权重参数。我按照文档中所述的相同方法计算权重:否定样本数除以肯定数。

结果是,我的模型高估了几乎每个班级……小班和大班班的预测几乎是真实标签的两倍。而我的AUPRC仅为0.18。即使比根本没有加权要好得多,因为在这种情况下,该模型会将所有内容预测为零。

所以我的问题是,如何提高性能?我还能做些什么吗?我尝试了不同的批量采样技术(以过度采样少数族裔),但是它们似乎没有用。

2 个答案:

答案 0 :(得分:1)

我建议使用其中一种策略

失踪

引入了一种非常有趣的方法,通过调整损失函数来处理不平衡的训练数据
林宗义,Priya Goyal,Ross Girshick,He Kaiming He和Piotr Dollar Focal Loss for Dense Object Detection(ICCV 2017)。
他们建议修改二进制交叉熵损失,以减少容易分类的示例的损失和梯度,同时将精力“集中于”模型会产生严重错误的示例。

硬负开采

另一种流行的方法是进行“强负面挖矿”;也就是说,仅针对部分训练示例(即“硬”示例)传播梯度。
参见,例如:
Abhinav Shrivastava,Abhinav Gupta和Ross Girshick Training Region-based Object Detectors with Online Hard Example Mining(CVPR 2016)

答案 1 :(得分:0)

@Shai提供了在深度学习时代开发的两种策略。我想为您提供其他一些传统的机器学习选项:过采样欠采样

它们的主要思想是在开始训练之前通过采样来产生更平衡的数据集。请注意,您可能会遇到一些问题,例如丢失数据多样性(欠采样)和过度拟合训练数据(过采样),但这可能是一个很好的起点。

有关更多信息,请参见wiki link