为什么`K.cast`用于`categorical_accuracy`而不是`K.mean`?

时间:2017-09-13 09:02:04

标签: python keras

函数keras.metrics.binary_accuracy非常简单:

def binary_accuracy(y_true, y_pred):
    return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1)

https://github.com/fchollet/keras/blob/master/keras/metrics.py#L20

然而,函数keras.metrics.categorical_accuracy有不同之处:

def categorical_accuracy(y_true, y_pred):
    return K.cast(K.equal(K.argmax(y_true, axis=-1),
                          K.argmax(y_pred, axis=-1)),
                  K.floatx())

https://github.com/fchollet/keras/blob/master/keras/metrics.py#L24

我很困惑为什么这个函数使用K.cast而不是K.mean?因为我认为这个函数应该返回一个数字,就像函数keras.metrics.binary_accuracy

一样

1 个答案:

答案 0 :(得分:5)

cast的原因是因为argmax返回一个整数,它是最高值的索引。但结果必须是浮动的。

argmax功能:

argmax函数也会降低输入的等级。请注意它使用axis=-1,这意味着它将采用最后一个轴的最大值索引,消除该轴,但保留其他轴。

假设您的输入形状为(10 samples, 5 features),则返回的张量将仅为(10 samples,)

mean功能axis=-1

通常情况下,mean函数会返回一个标量,但如果仔细查看binary_accuracy,您还会注意到,axis=-1函数中使用了mean,它不会将输入减少到单个标量值。它以与argmax完全相同的方式减小张量,但在这种情况下,计算平均值。

输入(10,5)也会出现(10,)

最终结果:

因此,我们可以得出结论,两个指标都返回具有相同形状的张量。现在,他们都没有减少标量值的原因是因为Keras提供了更多的可能性,例如样本加权和一些其他额外的操作与损失(包括你自己的自定义损失,如果你想把Keras'作为一个基础)。这些将依赖于样本的损失。

稍后某处,Keras将计算最终平均值。