在tf.Estimator设置中使用tf.metrics.precision / recall计算F1分数

时间:2018-12-04 20:01:58

标签: python tensorflow tensorflow-estimator

我正在尝试在tf.Estimator设置中计算F1分数。

我见过这个SO question,但无法从中提取出可行的解决方案。

tf.Estimator的用处在于,它希望我提供一个值和一个更新操作,因此,现在,我的模型末尾有这段代码:

if mode == tf.estimator.ModeKeys.EVAL:
    with tf.variable_scope('eval'):
        precision, precision_update_op = tf.metrics.precision(labels=labels,
                                            predictions=predictions['class'],
                                            name='precision')

        recall, recall_update_op = tf.metrics.recall(labels=labels,
                                      predictions=predictions['class'],
                                      name='recall')

        f1_score, f1_update_op = tf.metrics.mean((2 * precision * recall) / (precision + recall), name='f1_score')

        eval_metric_ops = {
            "precision": (precision, precision_update_op),
            "recall": (recall, recall_update_op),
            "f1_score": (f1_score, f1_update_op)}

现在精度和召回率似乎还不错,但是在F1分数上,我一直保持nan

我应该如何使它正常工作?

编辑:

tf.contrib.metrics.f1_score可以实现有效的解决方案,但是由于contrib将在TF 2.0中被弃用,因此,我希望减少contrib的解决方案

4 个答案:

答案 0 :(得分:1)

我是这样做的:

def f1_score_class0(labels, predictions):
    """
    To calculate f1-score for the 1st class.
    """
    prec, update_op1 = tf.compat.v1.metrics.precision_at_k(labels, predictions, 1, class_id=0)
    rec,  update_op2 = tf.compat.v1.metrics.recall_at_k(labels, predictions, 1, class_id=0)

    return {
            "f1_Score_for_class0":
                ( 2*(prec * rec) / (prec + rec) , tf.group(update_op1, update_op2) )
    }

答案 1 :(得分:0)

1)您为什么要tf.metrics.mean?查全率和精度是标量值

2)您是否尝试打印f1_scoref1_update_op

3)他们在documentation of recall中提到了

  

为了估计数据流上的度量,该函数创建了一个update_op,该update_op更新这些变量并返回调用。 update_op通过相应的权重值加权每个预测

由于您是直接从处理更新的两个操作中获取F1分数的,请尝试执行tf.identity(实际上不会导致更改)

答案 2 :(得分:0)

f1值张量可以从精度和查全率值计算。指标必须是(值,update_op)元组。我们可以将tf.identity传递给f1。这对我有用:

import tensorflow as tf

def metric_fn(labels, logits):
    predictions = tf.argmax(logits, axis=-1)
    pr, pr_op = tf.metrics.precision(labels, predictions)
    re, re_op = tf.metrics.recall(labels, predictions)
    f1 = (2 * pr * re) / (pr + re)
    return {
        'precision': (pr, pr_op),
        'recall': (re, re_op),
        'f1': (f1, tf.identity(f1))
    }

答案 3 :(得分:0)