如何使Keras仅根据验证数据计算特定指标?

时间:2019-06-30 16:44:55

标签: python tensorflow keras

我在TensorFlow 1.14.0中使用tf.keras。我实现了一个自定义指标,该指标需要大量的计算,如果将其简单地添加到model.compile(..., metrics=[...])提供的指标列表中,则会减慢训练过程。

我如何让Keras在训练迭代期间跳过指标的计算,而在每个时期结束时在验证数据上进行计算(并打印)?

1 个答案:

答案 0 :(得分:2)

为此,您可以在指标计算中创建一个tf.Variable,以确定该计算是否继续进行,然后在使用回调运行测试时对其进行更新。例如

class MyCustomMetric(tf.keras.metrics.Metrics):

    def __init__(self, **kwargs):
        # Initialise as normal and add flag variable for when to run computation
        super(MyCustomMetric, self).__init__(**kwargs)
        self.metric_variable = self.add_weight(name='metric_varaible', initializer='zeros')
        self.update_metric = tf.Variable(False)

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Use conditional to determine if computation is done
        if self.update_metric:
            # run computation
            self.metric_variable.assign_add(computation_result)

    def result(self):
        return self.metric_variable

    def reset_states(self):
        self.metric_variable.assign(0.)

class ToggleMetrics(tf.keras.callbacks.Callback):
    '''On test begin (i.e. when evaluate() is called or 
     validation data is run during fit()) toggle metric flag '''
    def on_test_begin(self, logs):
        for metric in self.model.metrics:
            if 'MyCustomMetric' in metric.name:
                metric.on.assign(True)
    def on_test_end(self,  logs):
        for metric in self.model.metrics:
            if 'MyCustomMetric' in metric.name:
                metric.on.assign(False)