Tensorflow:具有交叉熵损失的加权稀疏softmax

时间:2017-07-20 14:44:17

标签: python tensorflow deep-learning softmax cross-entropy

我正在使用完全卷积神经网络进行图像分割(链接到论文):https://people.eecs.berkeley.edu/~jonlong/long_shelhamer_fcn.pdf

这可以被视为像素分类(最后每个像素都得到一个标签)

我正在使用tf.nn.sparse_softmax_cross_entropy_with_logits损失函数。

loss = tf.reduce_mean((tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                                      labels=tf.squeeze(annotation, squeeze_dims=[3]),
                                                                      name="entropy"))) 

一切进展顺利。但是,我看到一个类出现在绝大多数像素中(95%+),调用此类0.让我们说我们还有另外三个类,1,2和3。

将权重放入课程的最简单方法是什么?从本质上讲,我希望0级(如0.1)的体重非常低,而其他三个级别的体重正常为1。

我知道此功能存在:https://www.tensorflow.org/api_docs/python/tf/losses/sparse_softmax_cross_entropy

它只是在我看来它做了一些完全不同的事情,我不明白权重应该如何与标签具有相同的等级。我的意思是,在我的情况下,权重应该类似于Tensor([0.1,1,1,1]),因此形状(4,)和等级1,而标签具有形状(batch_size,width,height),因此等级3。我错过了什么吗?

PyTorch上的等价物将是

torch.nn.CrossEntropyLoss(weight=None, size_average=True, ignore_index=-100)

其中权重是火炬张量[0.1,1,1,1]

谢谢!

1 个答案:

答案 0 :(得分:1)

您的猜测是正确的,tf.losses.softmax_cross_entropy中的weights参数和tf.losses.sparse_softmax_cross_entropy表示批次的权重,即输入示例< / em>比其他人更重要。没有开箱即用的方法可以减轻的损失。

作为解决方法,您可以根据当前标签专门选择权重并将其用作批量权重。这意味着每个批次的权重向量将不同,但会尝试使偶尔的稀有类更重要。请参阅this question中的示例代码。

注意:由于批次不一定包含统一的类分布,因此这种技巧对于小批量大小效果不佳,并且随着批量大小的增加而变得更好。当批量大小为1时,它完全没用。所以让批次尽可能大。