多标签分类丢失功能

时间:2018-05-17 15:25:10

标签: python tensorflow neural-network classification multilabel-classification

我在许多地方已经看到,对于使用神经网络的多标签分类,一个有用的损失函数是每个输出节点的二进制交叉熵。

在Tensorflow中,它看起来像这样:

cost = tf.nn.sigmoid_cross_entropy_with_logits()

这为数组提供了与我们拥有的输出节点一样多的值。

我的问题是,这个成本函数是否应该在输出节点数上取平均值?在Tensorflow中看起来像:

cost = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits())

或者每项损失都是独立处理的?

谢谢

1 个答案:

答案 0 :(得分:3)

对于多标签分类中的N标签,您是否对每个类的损失求和,或者是否使用tf.reduce_mean计算平均损失并不重要:渐变会指向同一个方向。

但是,如果您将总和除以N(这是平均值基本上是这个),这将影响一天结束时的学习率。如果您不确定多标签分类任务中有多少个标签,则可能更容易使用tf.reduce_mean,因为您不必重新调整此损失组件的重量。损失的组成部分,您不必调整标签更改N的学习率。