引擎盖实施下的TensorFlow Keras``准确性''指标

时间:2019-09-18 22:24:18

标签: python tensorflow machine-learning deep-learning tf.keras

使用TensorFlow Keras构建分类器时,通常会在编译步骤中通过指定metrics=['accuracy']来监控模型的准确性:

model = tf.keras.Model(...)
model.compile(optimizer=..., loss=..., metrics=['accuracy'])

无论模型是否输出对数或类概率,以及模型是否期望地面真值标签为单热编码矢量或整数索引(即,区间[0, n_classes)中的整数),这都可以正常运行)。

如果要使用交叉熵损失,则不是这种情况:上述情况的四种组合中的每一种都需要在编译步骤中传递不同的loss值:

  1. 如果模型输出概率,并且地面真相标签被一键编码,那么loss='categorical_crossentropy'可以工作。

  2. 如果模型输出概率,并且地面真实标签是整数索引,那么loss='sparse_categorical_crossentropy'可以工作。

  3. 如果模型输出对数,并且地面真相标签是一次热编码的,那么loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True)可以工作。

  4. 如果模型输出对数,并且地面真值标签是整数索引,则loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)有效。

似乎仅指定loss='categorical_crossentropy'不足以处理这四种情况,而指定metrics=['accuracy'] 则足够健壮。


问题 当用户在模型编译步骤中指定metrics=['accuracy']时,幕后发生了什么,无论模型输出对数还是概率,以及地面真相标签是否为一维编码,都可以正确执行精度计算向量还是整数索引?


我怀疑对数与概率的情况很简单,因为可以通过argmax的任意一种方式获取预测的类,但是理想情况下,我想指出TensorFlow 2源代码中实际进行计算的地方。

请注意,我当前正在使用TensorFlow 2.0.0-rc1


修改 在纯Keras中,metrics=['accuracy']Model.compile method中被显式处理。

1 个答案:

答案 0 :(得分:2)

找到了它:这在tensorflow.python.keras.engine.training_utils.get_metric_function中进行了处理。特别是,检查输出形状以确定要使用哪个精度函数。

要详细说明,在当前实现中,Model.compile要么将度量处理委托给Model._compile_eagerly(如果急切执行),要么直接执行。无论哪种情况,都将调用Model._cache_output_metric_attributes,这将同时对未加权和加权指标都调用collect_per_output_metric_info。此函数遍历提供的指标,在​​每个指标上调用get_metric_function