我想实施一个两步学习过程,其中:
1)使用损失函数loss_1
对模型进行几个时期的预训练
2)将损失函数更改为loss_2
,然后继续进行微调训练
目前,我的方法是:
model.compile(optimizer=opt, loss=loss_1, metrics=['accuracy'])
model.fit_generator(…)
model.compile(optimizer=opt, loss=loss_2, metrics=['accuracy’])
model.fit_generator(…)
请注意,优化器保持不变,仅损失函数发生变化。我想平稳地继续训练,但是损失功能有所不同。根据{{3}},重新编译模型会失去优化器状态。问题:
a)即使使用相同优化器(例如Adam),我也会丢失优化器状态吗?
b)如果a)的答案是肯定的,那么关于如何在不重置优化器状态的情况下将损失函数更改为新函数的任何建议?
编辑:
正如西蒙·卡比(Simon Caby)所建议并基于this post,我创建了一个自定义损失函数,其中包含两个根据历元数的损失计算。但是,它对我不起作用。我的方法:
def loss_wrapper(t_change, current_epoch):
def custom_loss(y_true, y_pred):
c_epoch = K.get_value(current_epoch)
if c_epoch < t_change:
# compute loss_1
else:
# compute loss_2
return custom_loss
在初始化current_epoch
之后,我如下编译:
current_epoch = K.variable(0.)
model.compile(optimizer=opt, loss=loss_wrapper(5, current_epoch), metrics=...)
要更新current_epoch
,我创建一个回调:
class NewCallback(Callback):
def __init__(self, current_epoch):
self.current_epoch = current_epoch
def on_epoch_end(self, epoch, logs={}):
K.set_value(self.current_epoch, epoch)
model.fit_generator(..., callbacks=[NewCallback(current_epoch)])
回调在每个时期正确更新self.current_epoch
。但是更新未达到自定义丢失功能。相反,current_epoch
会永久保留初始化值,并且永远不会执行loss_2
。
欢迎任何建议,谢谢!
答案 0 :(得分:1)
我的答案: a)是的,您可能应该制作自己的学习率调度程序以控制它:
keras.callbacks.LearningRateScheduler(schedule, verbose=0)
b)是的,您可以创建自己的损失函数,包括在两种不同损失方法之间变动的函数。请参阅:“高级Keras -构建复杂的自定义损失和指标” https://towardsdatascience.com/advanced-keras-constructing-complex-custom-losses-and-metrics-c07ca130a618
答案 1 :(得分:0)
如果您更改:
def loss_wrapper(t_change, current_epoch):
def custom_loss(y_true, y_pred):
c_epoch = K.get_value(current_epoch)
if c_epoch < t_change:
# compute loss_1
else:
# compute loss_2
return custom_loss
收件人:
def loss_wrapper(t_change, current_epoch):
def custom_loss(y_true, y_pred):
# compute loss_1 and loss_2
bool_case_1=K.less(current_epoch,t_change)
num_case_1=K.cast(bool_case_1,"float32")
loss = (num_case_1)*loss_1 + (1-num_case_1)*loss_2
return loss
return custom_loss
有效。
从本质上讲,我们需要将python代码转换为后端函数的组合,以使丢失工作而不必在model.compile(...)
的重新编译中进行更新。我对这些骇客不满意,并希望可以在回调中设置model.loss
而不用在以后重新编译model.compile(...)
(因为优化器状态被重置)。