如何将tensorflow.metrics.x与SummarySaverHook和Estimator结合使用?

时间:2019-04-15 16:42:29

标签: python tensorflow

我的问题与此issue有关。我使用自定义tf.estimator.Estimator,并希望查看几种不同指标的学习曲线。我使用tf.train.SummarySaverHooktf.train.LoggingTensorHook。例如,我想添加accuracy并在Tensorboad上进行查看。我执行以下操作:

acc_value, acc_op = tf.metrics.accuracy(labels=labels, predictions=preds)
tf.summary.scalar('metrics_accuracy', acc_op)

一切正常,但是一切正常,因为我使用了acc_op,它总是非零。另一方面,某些指标为其None返回op,并且使用它们的唯一方法是执行tf.summary.scalar('metrics_accuracy', acc_value)。这是issue中讨论的问题。 metrics.x值的第一个值始终为零,这就是训练期间始终打印的值。如何使用?

P.S .:对于其op没有价值的指标为dynamic_streaming_auc,并在here处讨论了该问题。不,我没有使用它,而是使用它的修改版本-自定义auc。

1 个答案:

答案 0 :(得分:0)

我想出了办法。这很棘手。可以通过修改SummarySaverHook以在update_ops方法中调用before_run来做到这一点。步骤如下。

1)在model_fn定义估算器中,将指标op添加到tf.GraphKeys.UPDATE_OPS并将其值添加到tf.summary.scalar

acc_value, acc_op = tf.metrics.accuracy(labels=labels, predictions=preds)
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, acc_op)
tf.summary.scalar('accuracy', acc_value)

2)创建将调用所有UpdateOpsHook的{​​{1}}:

update_ops
  1. 将此钩子添加到训练class UpdateOpsHook(tf.train.SessionRunHook): """Hook to execute all `update_ops` from tf.GraphKeys.UPDATE_OPS before each run. One needs to call `update_ops` to see metric values during training.""" def __init__(self): # Get all update_ops for (streaming) metrics, which are added # into `tf.GraphKeys.UPDATE_OPS` during creation of the graph self._update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, scope='metrics') def begin(self): self._global_step_tensor = tf.train.get_global_step() if self._global_step_tensor is None: raise RuntimeError("Global step should be created to use UpdateOpsHook.") def before_run(self, run_context): # Run `update_ops` return tf.train.SessionRunArgs(fetches=self._update_ops) 中,不要忘记将标量添加到评估EstimatorSpec中:
EstimatorSpec