我知道使用fit_generator
方法训练时很容易提及时代的数量。我有很多要训练的图像,我不能一次使用数组加载它们,因为它显示MemoryError
。我需要在达到一定的验证准确度(例如98%)后停止培训。如果在给定的历元数量之后尚未达到准确度,则训练将停止。在Keras有什么办法吗?我正在使用Tensorflow后端。
编辑:我在Keras看过EarlyStopping
模块,但它只跟踪监控数量的变化。
答案 0 :(得分:5)
您可以从Keras获取EarlyStopping
的代码。
class EarlyStoppingByAccuracy(Callback):
def __init__(self, monitor='accuracy', value=0.98, verbose=0):
super(Callback, self).__init__()
self.monitor = monitor
self.value = value
self.verbose = verbose
def on_epoch_end(self, epoch, logs={}):
current = logs.get(self.monitor)
if current is None:
warnings.warn("Early stopping requires %s available!" % self.monitor, RuntimeWarning)
if current >= self.value:
if self.verbose > 0:
print("Epoch %05d: early stopping THR" % epoch)
self.model.stop_training = True
自定义提前停止可以像下面这样使用
callbacks = [
EarlyStoppingByAccuracy(monitor='accuracy', value=0.98, verbose=1),
ModelCheckpoint(kfold_weights_path, monitor='val_loss', save_best_only=True, verbose=0),
]
model.fit(X_train.astype('float32'), Y_train, batch_size=batch_size, nb_epoch=nb_epoch,
shuffle=True, verbose=1, validation_data=(X_valid, Y_valid),
callbacks=callbacks)