Does it scale twice in the Keras categorical_crossentropy code?

Asked: 2019-07-01 01:45:49

Tags: tensorflow keras

I see that categorical_crossentropy is implemented in Keras (in keras/backend/tensorflow_backend.py, where tf is TensorFlow and epsilon() and _to_tensor are backend helpers) as follows:

def categorical_crossentropy(target, output, from_logits=False, axis=-1):
    """Categorical crossentropy between an output tensor and a target tensor.
    # Arguments
        target: A tensor of the same shape as `output`.
        output: A tensor resulting from a softmax
            (unless `from_logits` is True, in which
            case `output` is expected to be the logits).
        from_logits: Boolean, whether `output` is the
            result of a softmax, or is a tensor of logits.
        axis: Int specifying the channels axis. `axis=-1`
            corresponds to data format `channels_last`,
            and `axis=1` corresponds to data format
            `channels_first`.
    # Returns
        Output tensor.
    # Raises
        ValueError: if `axis` is neither -1 nor one of
            the axes of `output`.
    """
    output_dimensions = list(range(len(output.get_shape())))
    if axis != -1 and axis not in output_dimensions:
        raise ValueError(
            '{}{}{}'.format(
                'Unexpected channels axis {}. '.format(axis),
                'Expected to be -1 or one of the axes of `output`, ',
                'which has {} dimensions.'.format(len(output.get_shape()))))
    # Note: tf.nn.softmax_cross_entropy_with_logits
    # expects logits, Keras expects probabilities.
    if not from_logits:
        # scale preds so that the class probas of each sample sum to 1
        output /= tf.reduce_sum(output, axis, True)
        # manual computation of crossentropy
        _epsilon = _to_tensor(epsilon(), output.dtype.base_dtype)
        output = tf.clip_by_value(output, _epsilon, 1. - _epsilon)
        return - tf.reduce_sum(target * tf.log(output), axis)
    else:
        return tf.nn.softmax_cross_entropy_with_logits(labels=target,
                                                       logits=output)
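For reference, a minimal usage sketch (assuming standalone Keras 2.x with the TensorFlow 1.x backend, matching the snippet above; the tensor values are made up) showing the two code paths of this function:

    from keras import backend as K

    # Illustrative values only.
    target = K.constant([[0., 1., 0.]])
    probs = K.constant([[0.2, 0.7, 0.1]])    # already softmax-normalized
    logits = K.constant([[1.0, 2.0, 0.5]])   # raw scores

    # Probability path: renormalize, clip, then -sum(target * log(output)).
    loss_probs = K.categorical_crossentropy(target, probs)

    # Logit path: delegates to tf.nn.softmax_cross_entropy_with_logits.
    loss_logits = K.categorical_crossentropy(target, logits, from_logits=True)

    print(K.eval(loss_probs), K.eval(loss_logits))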

I don't understand these two lines:

output_dimensions = list(range(len(output.get_shape())))

and

output /= tf.reduce_sum(output, axis, True)

I understand that output is a tensor of probabilities produced by a softmax, which already means the class probabilities of each sample sum to 1. Why does the code then need to "scale preds so that the class probas of each sample sum to 1" again? Please explain.
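To make the premise concrete, a quick NumPy check (illustrative values): a softmax output already sums to 1 along the class axis, so for such an input the division looks redundant:

    import numpy as np

    logits = np.array([1.0, 2.0, 0.5])       # made-up scores
    softmax = np.exp(logits) / np.exp(logits).sum()
    print(softmax.sum())                     # 1.0 (up to floating-point rounding)
    print(softmax / softmax.sum())           # identical to softmax: a no-op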

1 answer:

Answer 0: (score: 0)

Because you need to make sure that every probability lies between 0 and 1; otherwise the cross-entropy computation would be incorrect. It is also a way to guard against user error when out-of-range (unnormalized) probabilities are passed in.
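A minimal NumPy sketch (illustrative values, not the Keras code itself) of why the division is a safeguard: for an unnormalized input it restores a valid distribution before the log is taken:

    import numpy as np

    target = np.array([0.0, 1.0, 0.0])       # made-up one-hot label

    bad_out = np.array([0.4, 1.4, 0.2])      # "probabilities" that sum to 2.0
    fixed = bad_out / bad_out.sum()          # back to [0.2, 0.7, 0.1], a valid distribution
    print(-np.sum(target * np.log(fixed)))   # cross-entropy on the fixed values

So when output really does come from a softmax, that line divides by 1.0 and nothing is scaled twice; the only cost is one extra reduce_sum, paid to keep the manual -sum(target * log(output)) well-defined for arbitrary inputs.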