计算TensorFlow中多个类的平均和类精度/召回率

时间:2019-10-02 13:08:51

标签: tensorflow keras

我有一个包含4个类的多类模型。我已经实现了一个回调,该回调能够计算每个类及其宏平均值的精度/调用率。但是出于某种技术原因,我必须使用指标机制来计算它们。

我正在使用TensorFlow 2和Keras 2.3.0。我已经使用tensorflow.keras.metrics.Recall/Precision来获取类指标:

metrics_list = ['accuracy']
    metrics_list.extend([Recall(class_id=i, name="recall_{}".format(label_names[i])) for i in range(n_category)])
    metrics_list.extend([Precision(class_id=i, name="precision_{}".format(label_names[i])) for i in range(n_category)])

model = Model(...)
model.compile(...metrics=metrics_list)

但是,这种解决方案并不令人满意:

  1. 首先,tensorflow.keras.metrics.Recall/Precision使用阈值定义与某个类的隶属关系,而如果定义了argmax,则应使用class_id定义最可能的类

  2. 第二,我必须创建2个新指标来计算所有类的平均值,而这本身就需要计算类指标。计算同一件事的两倍是不礼貌且效率低下的。

是否有一种方法可以创建一个使用TensorFlow / Keras度量逻辑直接计算类和平均谓词/调用的类或函数?

显然,我可以使用tf.math.confusion_matrix()轻松获得混淆矩阵。但是,我看不到如何立即注入一个标量列表,而不是返回单个标量。

欢迎任何评论!

1 个答案:

答案 0 :(得分:0)

在我的特定情况下,由于我使用的是CategoricalAccuracy(),因此我可以简单地将batch_size=1用作唯一指标。在这种情况下,accuracy=recall=precision={1.|0.}为批次。那只能部分解决问题。最好的解决方案是在每个批次的末尾使用argmax更新混淆矩阵,然后根据该值计算Precision / Recall。我还不知道怎么做,但是应该可行。