如何处理多标签分类的不平衡数据集

时间:2017-06-02 13:49:02

标签: tensorflow deep-learning multilabel-classification

我想知道在处理真正不平衡的数据集时如何惩罚代表性较少的类而不是其他类(在大约20000个样本中有10个类,但这里是每个类的出现次数:[10868 26 4797 26 8320 26 5278 9412] 4485 16172])。

我读到了Tensorflow函数:weighted_cross_entropy_with_logits(https://www.tensorflow.org/api_docs/python/tf/nn/weighted_cross_entropy_with_logits)但我不确定是否可以将它用于多标签问题。

我找到了一篇文章,总结了我的问题(Neural Network for Imbalanced Multi-Class Multi-Label Classification)并提出了一个想法,但它没有答案,我认为这个想法可能会很好:)

感谢您的想法和答案!

3 个答案:

答案 0 :(得分:0)

首先,我的建议是您可以修改成本函数以便以多标签方式使用。有code显示如何在Tensorflow中使用Softmax Cross Entropy进行多标记图像任务。

使用该代码,您可以在每行损失计算中使用多个权重。以下是具有多标签任务的示例代码:(即,每个图像可以有两个标签)

logits_split  = tf.split( axis=1, num_or_size_splits=2, value= logits  ) 
labels_split  = tf.split( axis=1, num_or_size_splits=2, value= labels  )
weights_split = tf.split( axis=1, num_or_size_splits=2, value= weights )
total         = 0.0

for i in range ( len(logits_split) ):  
    temp   = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits( logits=logits_split[i] , labels=labels_split[i] )) 
    total += temp * tf.reshape(weights_split[i],[-1])  

答案 1 :(得分:0)

我认为您可以使用tf.nn.weighted_cross_entropy_with_logits进行多类分类。

例如,对于4个类,其中具有最大成员数的类的比率为[0.8, 0.5, 0.6, 1],您只需按以下方式为其赋予权重向量:

cross_entropy = tf.nn.weighted_cross_entropy_with_logits(
        targets=ground_truth_input, logits=logits, 
        pos_weight = tf.constant([0.8,0.5,0.6,1]))

答案 2 :(得分:0)

因此,鉴于你所写的内容,我并不完全确定我理解你的问题。你链接到的帖子写了关于多标签和多类的文章,但是根据那里写的内容,这并没有真正意义。因此,我会将此作为一个多类问题来处理,对于每个样本,您只有一个标签。

为了惩罚课程,我根据当前批次中的标签实施了一个权重张量。对于3级问题,您可以例如。将权重定义为类的反向频率,使得如果对于类1,2和3的比例分别为[0.1,0.7,0.2],则权重将为[10,1.43,5]。然后根据当前批次定义权重张量

weight_per_class = tf.constant([10, 1.43, 5]) # shape (, num_classes)
onehot_labels = tf.one_hot(labels, depth=3) # shape (batch_size, num_classes)
weights = tf.reduce_sum(
    tf.multiply(onehot_labels, weight_per_class), axis=1) # shape (batch_size, num_classes)
reduction = tf.losses.Reduction.MEAN # this ensures that we get a weighted mean
loss = tf.losses.softmax_cross_entropy(
        onehot_labels=onehot_labels, logits=logits, weights=weights, reduction=reduction)

使用softmax可确保分类问题不是3个独立的分类。