在使用带有Estimator API的自定义SessionRunHook时,“无法使用`eval()`评估张量:未注册默认会话”

时间:2018-08-04 12:32:48

标签: python tensorflow tensorflow-estimator

我关注this example,以学习如何使用Estimator API构建TensorFlow的CNN。在给定的示例中,一行pred_probas = tf.nn.softmax(logits_test)对于我来说很有价值,如果我能够获得这些概率,因为我想在我编写的这个小代码段中使用它们:

def eer_eval(y_true, probas):
    fpr, tpr, thresholds = roc_curve(y_true.eval(), probas[:, 1].eval())
    return brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)

阅读this post之后,我写了自己的钩子

class _EERHook(tf.train.SessionRunHook):
    def __init__(self, probas, labels):
        self.labels = labels
        self.probas = probas

    def begin(self):
        pass

    def before_run(self, run_context):
        return tf.train.SessionRunArgs(eer_eval(self.labels, self.probas))

    def after_run(self,
                run_context,  # pylint: disable=unused-argument
                run_values):
        eer = run_values.results
        print("EER: ", eer)

我想在模型评估期间使用

estim_specs = tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=pred_classes,
        loss=loss_op,
        train_op=train_op,
        eval_metric_ops={'accuracy': acc_op},
        evaluation_hooks=[_EERHook(pred_probas, labels)])

但是,代码因错误而崩溃

ValueError: Cannot evaluate tensor using `eval()`: No default session is registered. Use `with sess.as_default()` or pass an explicit session to `eval(session=sess)`

有什么方法可以在评估过程中将这些概率保存到人类可读的csv文件中,或者可以使我的代码段正常工作吗?

1 个答案:

答案 0 :(得分:1)

此功能eer_eval(y_true, probas)显然不是张量流样式。因此,也许最好让钩子计算y_trueprobas并将numpy的值赋予eer_eval()

_EERHook中:

def before_run(self, run_context):
    return tf.train.SessionRunArgs((self.labels, self.probas))

def after_run(self,
            run_context,  # pylint: disable=unused-argument
            run_values):
    results = run_values.results
    print('labels:', results[0])
    print('probas:', results[1])
    # err_eval(results[0], results[1])