回调中的Keras指标添加到了Tensorflow图

时间:2019-05-18 16:02:01

标签: python tensorflow keras

我正在用Keras训练模型,并定义了一个回调,在每个回调之后我都会计算一些指标。我使用keras.metrics之类的mean_absolute_error中的函数。我的回调如下:

class CalcMae(Callback): 

    def __init__(self):
        placeholder = np.zeros((batch_size, 1))
        self.val_y_true = K.variable(placeholder, name='val_y_true')
        self.val_y_pred = K.variable(placeholder, name='val_y_pred')

    def on_epoch_end(self, epoch, logs={}):
        # get some values for y_true and y_pred
        K.set_value(self.val_y_true, local_val_y_true)
        K.set_value(self.val_y_pred, local_val_y_pred)

        val_mea = K.eval(keras.metrics.mean_absolute_error(self.val_y_true, self.val_y_pred)))

现在,每个纪元训练的速度都变慢了一点,我在使用tensorboard时,发现对mean_absolute_error的调用将节点添加到tensorflow图中。我的猜测是,图表在每个时期都会变大,从而导致时期变慢。

我的问题是:为什么在仅仅计算一个值时,调用mean_absolute_error(或其他度量/损耗)会在tensorflow图上添加节点?为此,最好定义自己的度量函数吗?

(我无法使用Keras本身的指标,因为我需要一些无法通过指标函数访问的额外数据)

修改

我通过numpy操作实现了我需要的损失和指标。现在培训还没有放缓(尚未)。我仍然想知道为什么简单的计算需要将节点添加到tensorflow图。

0 个答案:

没有答案