如何在Keras中使用CNN处理用于多标签分类的不平衡数据?

时间:2019-12-27 14:14:56

标签: python keras multilabel-classification imbalanced-data

我的数据集形状为(91149, 12)

我使用CNN在文本分类任务中训练分类器

我发现训练准确性:0.5923和测试准确性:0.5780

我的班级有9个标签,如下所示:

df['thematique'].value_counts()
Corporate                   42399
Economie collaborative      13272
Innovation                  11360
Filiale                      5990
Richesses Humaines           4445
Relation sociétaire          4363
Communication                4141
Produits et services         2594
Sites Internet et applis     2585

模型结构:

model = Sequential()
embedding_layer = Embedding(vocab_size, 300, weights=[embedding_matrix],   input_length=maxlen   ,   trainable=False)
model.add(embedding_layer)
model.add(Conv1D(128, 7, activation='relu'))
model.add(GlobalMaxPooling1D())
model.add(Dense(9, activation='sigmoid'))
model.compile(optimizer='Adam', loss='categorical_crossentropy', metrics= ['categorical_accuracy'])

我的多标签分类数据不平衡。我需要在Keras中使用CNN处理不平衡数据,以便进行多分类分类。

2 个答案:

答案 0 :(得分:1)

准确度可能会误导您的问题,因为类别失衡严重,我会使用 F1 得分。

对于损失,您可以使用focal loss,它是分类交叉熵的变体,其重点是最少代表类。您可以找到一个示例here,根据我的经验,使用NLP分类任务上很少的课程对它有很大帮助。

答案 1 :(得分:0)

我不确定您是否需要使用Keras本身而不是凭直觉来解决不平衡问题。一种简单的方法是每个类使用相同数量的数据。当然,这会引起另一个问题,即您过滤了很多样本​​。但是仍然是可以检查的事情。当然,当数据不平衡时,仅计算分类性能并不是一个好主意,因为它很好地完成了每个类的性能。

您还应该计算混淆矩阵,以可视化每个班级的表现。在此bloghere中可以找到解决不平衡数据问题的更详细方法。

最重要的是使用正确的工具来评估分类的性能,并按照我提到的链接中的建议处理输入数据。