当精度已经达到1.0时,如何停止Keras Training?我尝试监视损失值,但是在精度已经达到1的情况下,我没有尝试停止训练。
我很幸运地尝试了以下代码:
stopping_criterions =[
EarlyStopping(monitor='loss', min_delta=0, patience = 1000),
EarlyStopping(monitor='acc', base_line=1.0, patience =0)
]
model.summary()
model.compile(Adam(), loss='binary_crossentropy', metrics=['accuracy'])
model.fit(scaled_train_samples, train_labels, batch_size=1000, epochs=1000000, callbacks=[stopping_criterions], shuffle = True, verbose=2)
更新:
即使精度仍然不是1.0,训练也会立即在第一个纪元停止。
请帮助。
答案 0 :(得分:6)
据我所知,在基线回调中使用EarlyStopping并不能解决问题。 “基线”是您应该继续训练所要监视的变量的最小值(此处为准确性)。这里的基准是1.0,在第一个时期结束时,基线小于“准确性”(显然,您不能期望在第一个时期本身的“准确性”为1.0),并且由于耐心设置为零,因此训练停止在第一个纪元本身,因为基线大于准确性。 使用自定义回调在这里可以完成工作。
class MyThresholdCallback(tf.keras.callbacks.Callback):
def __init__(self, threshold):
super(MyThresholdCallback, self).__init__()
self.threshold = threshold
def on_epoch_end(self, epoch, logs=None):
accuracy = logs["acc"]
if accuracy >= self.threshold:
self.model.stop_training = True
并在model.fit中调用回调
callback=MyThresholdCallback(threshold=1.0)
model.fit(scaled_train_samples, train_labels, batch_size=1000, epochs=1000000, callbacks=[callback], shuffle = True, verbose=2)
答案 1 :(得分:3)
更新:我不知道为什么EarlyStopping
在这种情况下不起作用。相反,我定义了一个自定义回调,当acc
(或val_acc
)达到指定的基线时,该回调将停止训练:
from keras.callbacks import Callback
class TerminateOnBaseline(Callback):
"""Callback that terminates training when either acc or val_acc reaches a specified baseline
"""
def __init__(self, monitor='acc', baseline=0.9):
super(TerminateOnBaseline, self).__init__()
self.monitor = monitor
self.baseline = baseline
def on_epoch_end(self, epoch, logs=None):
logs = logs or {}
acc = logs.get(self.monitor)
if acc is not None:
if acc >= self.baseline:
print('Epoch %d: Reached baseline, terminating training' % (epoch))
self.model.stop_training = True
您可以像这样使用它:
callbacks = [TerminateOnBaseline(monitor='acc', baseline=0.8)]
callbacks = [TerminateOnBaseline(monitor='val_acc', baseline=0.95)]
注意:此解决方案无效。
如果要在训练(或验证)准确度完全达到100%时停止训练,请使用EarlyStopping
回调并将baseline
参数设置为1.0,并将patience
设置为零:
EarlyStopping(monitor='acc', baseline=1.0, patience=0) # use 'val_acc' instead to monitor validation accuarcy
答案 2 :(得分:2)
名称baseline
具有误导性。尽管从下面的源代码很难理解,baseline
应该理解为:
虽然监控值比基准值差 1 ,但应继续训练最长patience
个时期。如果更好,请提高基线并重复。
1 ,即精度较低,损失较高。
EarlyStopping
的相关(已整理)源代码:
self.best = baseline # in initialization
...
def on_epoch_end(self, epoch, logs=None):
current = self.get_monitor_value(logs)
if self.monitor_op(current - self.min_delta, self.best): # read as `current > self.best` (for accuracy)
self.best = current
self.wait = 0
else:
self.wait += 1
if self.wait >= self.patience:
self.model.stop_training = True
那么你的例子
EarlyStopping(monitor='acc', base_line=1.0, patience=0)
的意思是:虽然监视的值比1.0差(它始终是),但要继续训练0个纪元(即立即终止)。
如果需要这些语义:
尽管监测值比基线差,但仍要继续训练。如果更好,请继续训练直到patience
连续纪元为止没有任何进展,并保留EarlyStopping
的所有功能,我可以建议:
class MyEarlyStopping(EarlyStopping):
def __init__(self, *args, **kw):
super().__init__(*args, **kw)
self.baseline_attained = False
def on_epoch_end(self, epoch, logs=None):
if not self.baseline_attained:
current = self.get_monitor_value(logs)
if current is None:
return
if self.monitor_op(current, self.baseline):
if self.verbose > 0:
print('Baseline attained.')
self.baseline_attained = True
else:
return
super(MyEarlyStopping, self).on_epoch_end(epoch, logs)