SessionRunHook在运行后返回空的SessionRunValues

时间:2018-11-08 23:12:43

标签: python-3.x tensorflow keras tensorflow-estimator

我正在尝试编写一个钩子,该钩子将允许我计算一些全局指标(而不是按批处理的指标)。作为原型,我认为我将获得一个简单的连接并运行,它将捕获并记住真实的积极信息。看起来像这样:

class TPHook(tf.train.SessionRunHook):

    def after_create_session(self, session, coord):
        print("Starting Hook")

        tp_name = 'metrics/f1_macro/TP'
        self.tp = []
        self.args = session.graph.get_operation_by_name(tp_name)
        print(f"Got Args: {self.args}")

    def before_run(self, run_context):
        print("Starting Before Run")
        return tf.train.SessionRunArgs(self.args)

    def after_run(self, run_context, run_values):
        print("After Run")
        print(f"Got Values: {run_values.results}")

但是,在挂钩的“ after_run”部分中返回的值始终为“无”。我在训练和评估阶段都对此进行了测试。我对SessionRunHooks应该如何工作有误解吗?


也许相关信息: 该模型是在keras中构建的,并通过keras.estimator.model_to_estimator()函数转换为估计量。该模型已经过测试并且可以正常工作,并且我在钩子中尝试检索的操作已在以下代码块中定义:

def _f1_macro_vector(y_true, y_pred):
    """Computes the F1-score with Macro averaging.

    Arguments:
        y_true {tf.Tensor} -- Ground-truth labels
        y_pred {tf.Tensor} -- Predicted labels

    Returns:
        tf.Tensor -- The computed F1-Score
    """
    y_true = K.cast(y_true, tf.float64)
    y_pred = K.cast(y_pred, tf.float64)

    TP = tf.reduce_sum(y_true * K.round(y_pred), axis=0, name='TP')
    FN = tf.reduce_sum(y_true * (1 - K.round(y_pred)), axis=0, name='FN')
    FP = tf.reduce_sum((1 - y_true) * K.round(y_pred), axis=0, name='FP')

    prec = TP / (TP + FP)
    rec = TP / (TP + FN)

    # Convert NaNs to Zero
    prec = tf.where(tf.is_nan(prec), tf.zeros_like(prec), prec)
    rec = tf.where(tf.is_nan(rec), tf.zeros_like(rec), rec)

    f1 = 2 * (prec * rec) / (prec + rec)

    # Convert NaN to Zero
    f1 = tf.where(tf.is_nan(f1), tf.zeros_like(f1), f1)

    return f1

1 个答案:

答案 0 :(得分:0)

万一有人遇到相同的问题,我发现了如何重组程序以使其运行。尽管文档听起来像我可以将原始操作传递给SessionRunArgs,但似乎它需要实际的张量(也许这是我的误读)。 这非常容易实现-我只是将after_create_session代码更改为以下所示。

def after_create_session(self, session, coord):

    tp_name = 'metrics/f1_macro/TP'
    self.tp = []
    tp_tensor = session.graph.get_tensor_by_name(tp_name+':0')

    self.args = [tp_tensor]

此操作成功运行。