使用tf.estimator.EstimatorSpec时,如何在每个时期后检查评估auc?

时间:2019-05-16 10:52:24

标签: tensorflow machine-learning deep-learning

我使用tf.estimator.EstimatorSpec定义了我的模型。我知道它具有训练,评估和预测模式。但是我想在每个时期后检查一些指标得分,例如auc。该API是否像keras一样支持它?

1 个答案:

答案 0 :(得分:0)

没有直接的API用于添加类似AUC的度量标准,但是您可以使用Custom Metric Function创建tf.keras.metrics,然后使用tf.estimator.add_metrics在Estimator中使用这些度量标准。

示例代码展示了AUC的实现,如下所示:

  def my_auc(labels, predictions):
    auc_metric = tf.keras.metrics.AUC(name="my_auc")
    auc_metric.update_state(y_true=labels, y_pred=predictions['logistic'])
    return {'auc': auc_metric}

  estimator = tf.estimator.DNNClassifier(...)
  estimator = tf.estimator.add_metrics(estimator, my_auc)
  estimator.train(...)
  estimator.evaluate(...)

  def my_auc(labels, predictions, features):
    auc_metric = tf.keras.metrics.AUC(name="my_auc")
    auc_metric.update_state(y_true=labels, y_pred=predictions['logistic'],
                            sample_weight=features['weight'])
    return {'auc': auc_metric}

  estimator = tf.estimator.DNNClassifier(...)
  estimator = tf.estimator.add_metrics(estimator, my_auc)
  estimator.train(...)
  estimator.evaluate(...)