使用tf.metrics.auc进行培训和验证

时间:2018-02-23 16:55:39

标签: tensorflow

在使用相同的模型进行培训和验证时,是否有正确的方法来使用tf.metrics.auc?我目前将我的模型作为一个类实现,并使用tf.data字符串句柄在训练和验证数据之间进行交换。因为tf.metrics.auc函数创建局部变量,我需要采取其他步骤来确保它们不在培训和验证之间共享吗?

最终我想将ROC和PR值下的区域提供给tf.summary.scalar,我可以在这里看到训练和验证,所以这使得使用{{1}进一步复杂化将图分割成Tensorboard中的单独部分。

1 个答案:

答案 0 :(得分:1)

我设法通过创建一个reset_op来解决这个问题,我曾经在执行摘要步骤之前将局部变量清空

eval_reset_ops = []
for scope in ['outcome_' + str(i) for i in range(cfg.num_class)]:
    for var in tf.local_variables(scope=scope):
        eval_reset_ops.append(tf.assign(var, tf.zeros_like(var)))

然后我运行以下部分来制作摘要。

_ = sess.run(eval_reset_ops)
for _ in range(n_summary_batches):
    eval_metrics, tr_losses = sess.run([model.eval_ops, model.losses])
summary = sess.run(merged)
train_writer.add_summary(summary, step)
train_writer.flush()

在这种情况下,model.eval_opstf.metrics.auc生成的更新操作列表。

如......

auc, auc_op = tf.metrics.auc(
    labels=labels, predictions=prob, num_thresholds=1000,
   updates_collections='eval_updates', curve='ROC')
aupr, pr_op = tf.metrics.auc(
    labels=labels, predictions=prob, num_thresholds=1000,
    updates_collections='eval_updates', curve='PR')
eval_ops = [auc_op, pr_op]

这似乎工作正常。我希望有人认为这很有用,或者可以指出如何改进它。