获取损失函数中的当前纪元

时间:2019-04-02 01:02:37

标签: keras deep-learning loss-function

我的目标是在自定义损失函数中获取当前纪元号。这对于i)根据时期使用不同的损失函数方法,或ii)在损失计算中使用一些与时期相关的参数alpha可能是有用的。例如:

current_epoch = K.variable(0.)

def custom_loss(y_true, y_pred):
   # i) use different loss function based on epoch
   c_epoch = K.get_value(current_epoch)
   if c_epoch < t_change:
       # compute loss_1
   else:
       # compute loss_2

   # ii) also useful to do: result = (1 - alpha)*loss_1 + alpha*loss_2
   return result

没有成功,我尝试使用Lambda回调获取当前纪元并将其传递给loss函数。特别是,我已经尝试了herehere提出的方法。本质上,它包括:

from keras.callbacks import LambdaCallback

def get_epoch(epoch):
    K.set_value(current_epoch, epoch)
    # or K.set_value(alpha, functionOf(epoch))

epochchanger = LambdaCallback(on_epoch_begin=get_epoch)

最后,编译并拟合:

model.compile(loss=custom_loss, metrics=...)
model.fit_generator(..., callbacks=[epochchanger])

var current_epochget_epoch()中的每个时期正确更新。但是更新后的current_epoch无法达到custom_loss功能。相反,current_epoch函数中的custom_loss永远保留初始化值。

关于如何在损失函数中获取更新的current_epoch的任何建议?注意:重新编译模型是更改损失函数的一种方法,但是我尝试不这样做,因为更改优化器状态不是最佳选择。谢谢!

0 个答案:

没有答案