不平衡数据和加权交叉熵

时间:2017-06-15 06:51:28

标签: python machine-learning tensorflow deep-learning

我正在尝试用不平衡的数据训练网络。我有A(198个样本),B个(436个样本),C个(710个样本),D个(272个样本),我读过“weighted_cross_entropy_with_logits”但我发现的所有例子都是二进制分类所以我不是很对如何设置这些重量充满信心。

总样本:1616

A_weight:198/1616 = 0.12?

如果我明白的话,背后的想法是惩罚了少数民族的错误,并且更加积极地评价少数民族的命中,对吗?

我的代码:

weights = tf.constant([0.12, 0.26, 0.43, 0.17])
cost = tf.reduce_mean(tf.nn.weighted_cross_entropy_with_logits(logits=pred, targets=y, pos_weight=weights))

我已阅读this one和其他二进制分类示例,但仍不太清楚。

提前致谢。

3 个答案:

答案 0 :(得分:61)

请注意,weighted_cross_entropy_with_logitssigmoid_cross_entropy_with_logits的加权变体。 Sigmoid交叉熵通常用于二进制分类。是的,它可以处理多个标签,但是S形交叉熵基本上对每个标签做出(二元)决策 - 例如,对于面部识别网,那些(不是互斥的)标签可以是" 受试者是否戴眼镜?"," 主题是女性吗?"等。

在二进制分类中,每个输出通道对应于二进制(软)决策。因此,加权需要在损失的计算中发生。这是weighted_cross_entropy_with_logits通过对交叉熵的一个项进行加权而对另一个进行加权的原因。

在互斥的多标签分类中,我们使用softmax_cross_entropy_with_logits,其行为不同:每个输出通道对应于候选类别的分数。通过比较每个通道的相应输出,之后做出决定。

因此,在最终决定之前加权是一个简单的问题,在比较之前修改得分,通常是通过乘以权重。例如,对于三元分类任务,

# your class weights
class_weights = tf.constant([[1.0, 2.0, 3.0]])
# deduce weights for batch samples based on their true label
weights = tf.reduce_sum(class_weights * onehot_labels, axis=1)
# compute your (unweighted) softmax cross entropy loss
unweighted_losses = tf.nn.softmax_cross_entropy_with_logits(onehot_labels, logits)
# apply the weights, relying on broadcasting of the multiplication
weighted_losses = unweighted_losses * weights
# reduce the result to get your final loss
loss = tf.reduce_mean(weighted_losses)

您还可以依靠tf.losses.softmax_cross_entropy来处理最后三个步骤。

在您需要解决数据不平衡的情况下,类权重确实与其列车数据中的频率成反比。将它们归一化以使它们总计为一个或多个类也是有意义的。

请注意,在上文中,我们根据样本的真实标签惩罚了损失。我们还可以通过简单定义

来基于估计的标签惩罚损失
weights = class_weights

由于广播魔术,其余代码无需更改。

在一般情况下,您需要权重取决于您所犯的错误类型。换句话说,对于每对标签XY,您可以选择在真实标签为X时如何惩罚选择标签Y。您最终得到一个完整的先前权重矩阵,这导致weights以上是一个完整的(num_samples, num_classes)张量。这有点超出了你想要的范围,但是要知道只有你的权重张量的定义需要在上面的代码中改变,这可能是有用的。

答案 1 :(得分:3)

Tensorflow 2.0兼容答案:为了社区的利益,迁移了P-Gn的2.0答案中指定的代码。

# your class weights
class_weights = tf.compat.v2.constant([[1.0, 2.0, 3.0]])
# deduce weights for batch samples based on their true label
weights = tf.compat.v2.reduce_sum(class_weights * onehot_labels, axis=1)
# compute your (unweighted) softmax cross entropy loss
unweighted_losses = tf.compat.v2.nn.softmax_cross_entropy_with_logits(onehot_labels, logits)
# apply the weights, relying on broadcasting of the multiplication
weighted_losses = unweighted_losses * weights
# reduce the result to get your final loss
loss = tf.reduce_mean(weighted_losses)

有关将代码从Tensorflow版本1.x迁移到2.x的更多信息,请参阅此Migration Guide

答案 2 :(得分:-1)

请参见this answer,了解与sparse_softmax_cross_entropy一起使用的替代解决方案