使用TensorFlow Keras构建分类器时,通常会在编译步骤中通过指定metrics=['accuracy']
来监控模型的准确性:
model = tf.keras.Model(...)
model.compile(optimizer=..., loss=..., metrics=['accuracy'])
无论模型是否输出对数或类概率,以及模型是否期望地面真值标签为单热编码矢量或整数索引(即,区间[0, n_classes)
中的整数),这都可以正常运行)。
如果要使用交叉熵损失,则不是这种情况:上述情况的四种组合中的每一种都需要在编译步骤中传递不同的loss
值:
如果模型输出概率,并且地面真相标签被一键编码,那么loss='categorical_crossentropy'
可以工作。
如果模型输出概率,并且地面真实标签是整数索引,那么loss='sparse_categorical_crossentropy'
可以工作。
如果模型输出对数,并且地面真相标签是一次热编码的,那么loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True)
可以工作。
如果模型输出对数,并且地面真值标签是整数索引,则loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
有效。
似乎仅指定loss='categorical_crossentropy'
不足以处理这四种情况,而指定metrics=['accuracy']
则足够健壮。
问题
当用户在模型编译步骤中指定metrics=['accuracy']
时,幕后发生了什么,无论模型输出对数还是概率,以及地面真相标签是否为一维编码,都可以正确执行精度计算向量还是整数索引?
我怀疑对数与概率的情况很简单,因为可以通过argmax的任意一种方式获取预测的类,但是理想情况下,我想指出TensorFlow 2源代码中实际进行计算的地方。
请注意,我当前正在使用TensorFlow 2.0.0-rc1。
修改
在纯Keras中,metrics=['accuracy']
在Model.compile
method中被显式处理。
答案 0 :(得分:2)
找到了它:这在tensorflow.python.keras.engine.training_utils.get_metric_function
中进行了处理。特别是,检查输出形状以确定要使用哪个精度函数。
要详细说明,在当前实现中,Model.compile
要么将度量处理委托给Model._compile_eagerly
(如果急切执行),要么直接执行。无论哪种情况,都将调用Model._cache_output_metric_attributes
,这将同时对未加权和加权指标都调用collect_per_output_metric_info
。此函数遍历提供的指标,在每个指标上调用get_metric_function
。