我正在用Keras训练模型,并定义了一个回调,在每个回调之后我都会计算一些指标。我使用keras.metrics
之类的mean_absolute_error
中的函数。我的回调如下:
class CalcMae(Callback):
def __init__(self):
placeholder = np.zeros((batch_size, 1))
self.val_y_true = K.variable(placeholder, name='val_y_true')
self.val_y_pred = K.variable(placeholder, name='val_y_pred')
def on_epoch_end(self, epoch, logs={}):
# get some values for y_true and y_pred
K.set_value(self.val_y_true, local_val_y_true)
K.set_value(self.val_y_pred, local_val_y_pred)
val_mea = K.eval(keras.metrics.mean_absolute_error(self.val_y_true, self.val_y_pred)))
现在,每个纪元训练的速度都变慢了一点,我在使用tensorboard时,发现对mean_absolute_error的调用将节点添加到tensorflow图中。我的猜测是,图表在每个时期都会变大,从而导致时期变慢。
我的问题是:为什么在仅仅计算一个值时,调用mean_absolute_error(或其他度量/损耗)会在tensorflow图上添加节点?为此,最好定义自己的度量函数吗?
(我无法使用Keras本身的指标,因为我需要一些无法通过指标函数访问的额外数据)
修改
我通过numpy操作实现了我需要的损失和指标。现在培训还没有放缓(尚未)。我仍然想知道为什么简单的计算需要将节点添加到tensorflow图。