解释keras代码categorical_crossentropy?

时间:2019-06-27 14:14:15

标签: tensorflow keras

这是绝对交叉熵的公式:Formula of categorical cross-entropy 我理解输出是来自softmax层的类的概率。那正确吗?目标是什么?以及该代码如何显示“ 1 / N”,“Σ”,“ p i,j ”?

def categorical_crossentropy(output, target, from_logits=False):
"""Categorical crossentropy between an output tensor and a target tensor.
# Arguments
    output: A tensor resulting from a softmax
        (unless `from_logits` is True, in which
        case `output` is expected to be the logits).
    target: A tensor of the same shape as `output`.
    from_logits: Boolean, whether `output` is the
        result of a softmax, or is a tensor of logits.
# Returns
    Output tensor.
"""
# Note: tf.nn.softmax_cross_entropy_with_logits
# expects logits, Keras expects probabilities.
if not from_logits:
    # scale preds so that the class probas of each sample sum to 1
    output /= tf.reduce_sum(output,
                            reduction_indices=len(output.get_shape()) - 1,
                            keep_dims=True)
    # manual computation of crossentropy
    epsilon = _to_tensor(_EPSILON, output.dtype.base_dtype)
    output = tf.clip_by_value(output, epsilon, 1. - epsilon)
    return - tf.reduce_sum(target * tf.log(output),
                           reduction_indices=len(output.get_shape()) - 1)
else:
    return tf.nn.softmax_cross_entropy_with_logits(labels=target,
                                                   logits=output)

1 个答案:

答案 0 :(得分:0)

  

我理解输出是来自softmax层的类的概率。那正确吗?

它可以是softmax层的输出,也可以是原始logits(输入到softmax层的输入)。 softmax层的输出向量是每个类别的概率。如果output是softmax的输出,则设置from_logits=False。如果output是登录名,那么您要设置from_logits=True。您可以在内部看到tf.nn.softmax_cross_entropy_with_logits被调用,它同时计算softmax概率和交叉熵函数。将它们一起计算可以为提高数值稳定性提供一些数学技巧。

  

目标是什么?

目标是一个热点。这意味着数字n由向量v表示,其中v[n] = 10在其他任何地方。 n是标签的类别。在TensoFlow中有一个名为tf.one_hot的函数可以获取这种编码。例如,tf.one_hot([3],5)将产生向量[0, 0, 1, 0, 0]

  

该代码如何显示“ 1 / N”,“Σ”,“ pi,j”?

上面的代码未对所有输入求平均值(不需要“ 1 / N”)。例如,如果输入的形状为[10, 5],则输出的形状为[10]。您将必须对结果调用tf.reduce_mean。因此,等式实质上是:

modified equation

上面的等式在一行中实现

return - tf.reduce_sum(target * tf.log(output),
                       reduction_indices=len(output.get_shape()) - 1)

“Σ”为tf.reduce_sum。 “ pi,j”是output,指示符函数(即粗体1)是一键编码的target

旁注

您应该使用tf.softmax_cross_entropy_with_logits_v2,因为您提供的代码(设置from_logits=False时)可能会导致数字错误。组合函数可以解决所有这些数字问题。