如何使用sparse_softmax_cross_entropy_with_logits在tensorflow中实现加权交叉熵损失

时间:2016-10-23 00:20:39

标签: python tensorflow deep-learning caffe cross-entropy

我开始使用tensorflow(来自Caffe),我正在使用损失sparse_softmax_cross_entropy_with_logits。该函数接受0,1,...C-1之类的标签而不是onehot编码。现在,我想根据类标签使用加权;我知道如果我使用softmax_cross_entropy_with_logits(一个热编码),可以使用矩阵乘法来完成这个,有没有办法对sparse_softmax_cross_entropy_with_logits做同样的事情?

3 个答案:

答案 0 :(得分:16)

import  tensorflow as tf
import numpy as np

np.random.seed(123)
sess = tf.InteractiveSession()

# let's say we have the logits and labels of a batch of size 6 with 5 classes
logits = tf.constant(np.random.randint(0, 10, 30).reshape(6, 5), dtype=tf.float32)
labels = tf.constant(np.random.randint(0, 5, 6), dtype=tf.int32)

# specify some class weightings
class_weights = tf.constant([0.3, 0.1, 0.2, 0.3, 0.1])

# specify the weights for each sample in the batch (without having to compute the onehot label matrix)
weights = tf.gather(class_weights, labels)

# compute the loss
tf.losses.sparse_softmax_cross_entropy(labels, logits, weights).eval()

答案 1 :(得分:2)

特别是对于二元分类,有weighted_cross_entropy_with_logits,它计算加权softmax交叉熵。

sparse_softmax_cross_entropy_with_logits是高效非加权操作的尾部(请参阅SparseSoftmaxXentWithLogitsOp,其中使用SparseXentEigenImpl),因此它不是"可插拔& #34;

在多类的情况下,您可以选择切换到单热编码或以hacky方式使用tf.losses.sparse_softmax_cross_entropy丢失功能,如已建议的那样,您必须根据标签中的标签传递权重当前批次。

答案 2 :(得分:1)

类权重乘以logits,因此仍适用于sparse_softmax_cross_entropy_with_logits。有关Tensor流中类不平衡二进制分类器的"损失函数,请参阅this solution。"

作为旁注,您可以将权重直接传递到sparse_softmax_cross_entropy

tf.contrib.losses.sparse_softmax_cross_entropy(logits, labels, weight=1.0, scope=None)

此方法用于使用

进行交叉熵丢失
tf.nn.sparse_softmax_cross_entropy_with_logits.

重量作为损失的系数。如果提供了标量,那么损失将简单地按给定值进行缩放。如果权重是一个大小的张量[batch_size],那么损失权重适用于每个相应的样本。