我想为我的卷积神经网络找到最佳学习率。但是,我使用的是自定义数据生成器,因为无法将整个数据集放入内存中。因此,我将以32个训练示例的形式读取数据。我调用自定义数据生成器的方式是:
params = {'batch_size': 32,
'n_classes': 5,
'shuffle': True}
training_generator = DataGenerator(partition_snp_final['train'], partition_pos_final['train'],labels_final, **params)
validation_generator = DataGenerator(partition_snp_final['valid'],partition_pos_final['valid'], labels_final, **params)
然后我像下面这样调用fit_generator:
callbacks_list = [earlystop,checkpoint]
model.fit_generator(generator=training_generator,epochs=10,validation_data=validation_generator,use_multiprocessing=True,
workers=6,callbacks=callbacks_list)
但是,现在我想将LR_finder包含在我的callbacks_list中,但是我不知道如何使其与fit_generator一起使用,因为在LR_finder中,我使用了一个我想成为的step_size:
但是训练和验证的步长会有所不同,因此在使用fit_generator时如何使LR_finder代码正常工作。
我的LR_finder代码是:
from keras.callbacks import Callback
import keras.backend as K
class LR_Finder(Callback):
def __init__(self, start_lr=1e-5, end_lr=10, step_size=None, beta=.98):
super().__init__()
self.start_lr = start_lr
self.end_lr = end_lr
self.step_size = step_size
self.beta = beta
self.lr_mult = (end_lr/start_lr)**(1/step_size)
def on_train_begin(self, logs=None):
self.best_loss = 1e9
self.avg_loss = 0
self.losses, self.smoothed_losses, self.lrs, self.iterations = [], [], [], []
self.iteration = 0
logs = logs or {}
K.set_value(self.model.optimizer.lr, self.start_lr)
def on_batch_end(self, epoch, logs=None):
logs = logs or {}
loss = logs.get('loss')
self.iteration += 1
self.avg_loss = self.beta * self.avg_loss + (1 - self.beta) * loss
smoothed_loss = self.avg_loss / (1 - self.beta**self.iteration)
print('smoothed loss : ',smoothed_loss)
# Check if the loss is not exploding
if self.iteration>1 and smoothed_loss > self.best_loss * 4:
self.model.stop_training = True
return
if smoothed_loss < self.best_loss or self.iteration==1:
self.best_loss = smoothed_loss
lr = self.start_lr * (self.lr_mult**self.iteration)
self.losses.append(loss)
self.smoothed_losses.append(smoothed_loss)
self.lrs.append(lr)
self.iterations.append(self.iteration)
K.set_value(self.model.optimizer.lr, lr)
对于如何修改此代码,使其适用于fit_generator,我们将不胜感激。