我的目标是在自定义损失函数中获取当前纪元号。这对于i)根据时期使用不同的损失函数方法,或ii)在损失计算中使用一些与时期相关的参数alpha
可能是有用的。例如:
current_epoch = K.variable(0.)
def custom_loss(y_true, y_pred):
# i) use different loss function based on epoch
c_epoch = K.get_value(current_epoch)
if c_epoch < t_change:
# compute loss_1
else:
# compute loss_2
# ii) also useful to do: result = (1 - alpha)*loss_1 + alpha*loss_2
return result
没有成功,我尝试使用Lambda回调获取当前纪元并将其传递给loss函数。特别是,我已经尝试了here和here提出的方法。本质上,它包括:
from keras.callbacks import LambdaCallback
def get_epoch(epoch):
K.set_value(current_epoch, epoch)
# or K.set_value(alpha, functionOf(epoch))
epochchanger = LambdaCallback(on_epoch_begin=get_epoch)
最后,编译并拟合:
model.compile(loss=custom_loss, metrics=...)
model.fit_generator(..., callbacks=[epochchanger])
var current_epoch
在get_epoch()
中的每个时期正确更新。但是更新后的current_epoch
无法达到custom_loss
功能。相反,current_epoch
函数中的custom_loss
永远保留初始化值。
关于如何在损失函数中获取更新的current_epoch
的任何建议?注意:重新编译模型是更改损失函数的一种方法,但是我尝试不这样做,因为更改优化器状态不是最佳选择。谢谢!