Keras callbacks with CV grid search

Time: 2018-03-04 09:13:35

Tags: python keras

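I am trying to implement a cross-validated grid search to tune the hyperparameters of a Keras model. Here is my code (it runs without errors, but it does not handle the callbacks properly):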

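from keras.models import Sequential
from keras.layers import LSTM, Dense
from keras.callbacks import ModelCheckpoint, CSVLogger
from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.model_selection import GridSearchCV

def create_model(optimizer, lstm_nodes):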
    model = Sequential()
    model.add(LSTM(lstm_nodes, dropout=0.25))
    model.add(Dense(5, activation='softmax'))
    model.compile(loss='categorical_crossentropy', 
              optimizer=optimizer,
              metrics=['accuracy'])
    return model


model = KerasClassifier(build_fn=create_model, epochs=15, verbose=0)

#define the grid search parameters
optimizer = ['Adam','SGD']
lstm_nodes = [12,18,24]
param_grid = dict(optimizer=optimizer,
                  lstm_nodes=lstm_nodes)

###### here is where the confusion happens ######
filepath = "weights-improvement-{epoch:02d}-{optimizer}.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=0, save_best_only=True, mode='max')
csv_logger = CSVLogger('log.csv', append=False, separator=',')
callback=[csv_logger,checkpoint]
#################################################

grid = GridSearchCV(estimator=model,cv=5, param_grid=param_grid, n_jobs=-1)
grid_result = grid.fit(xMat, yMat,validation_split = 0.1,callbacks=fit_params) 

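I am having trouble getting the callbacks to:

  1. Save the best model for each CV fold of each parameter combination.
  2. Log each parameter combination correctly.

Any help would be appreciated!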

1 answer:

Answer 0 (score: 1)

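There does not seem to be a way to checkpoint models properly with CV. However, if everything is logged, you can parse the file and find the best parameters. Here is how it can be done: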

Define a class that specifies what to do at the end of each epoch:

import keras

class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        self.losses = []

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # Append one comma-separated line per epoch.
        with open('somefile.txt', 'a') as f:
            stats = []
            stats.append(str(epoch))
            stats.append('Optimizer,' + self.model.optimizer.__class__.__name__)
            # 'batch_size' is not present in self.params on every Keras version.
            stats.append('Batch_size,' + str(self.params.get('batch_size')))
            # Older Keras versions report this metric under the key 'acc'.
            stats.append('accuracy,' + str(logs.get('accuracy')))
            stats.append('val_loss,' + str(logs.get('val_loss')))
            f.write(','.join(stats) + '\n')

Then initialize the history object and add it to the list of callbacks:

history = LossHistory()
grid = GridSearchCV(estimator=model,cv=5, param_grid=param_grid, n_jobs=-1)
grid_result = grid.fit(xMat, yMat,validation_split = 0.1,callbacks=[history]) 
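
Once the search has run, the log can be parsed to find the best-scoring line. The following is only a minimal sketch: it assumes the exact line layout written by LossHistory above (epoch first, then alternating key/value fields), and the helper names parse_log_line and best_record are hypothetical. The sample lines are made up for illustration, not real results.

```python
def parse_log_line(line):
    """Parse one 'epoch,key,value,key,value,...' line into a dict."""
    fields = line.strip().split(',')
    record = {'epoch': int(fields[0])}
    # The remaining fields alternate between a key and its value.
    for key, value in zip(fields[1::2], fields[2::2]):
        record[key] = value
    return record


def best_record(lines, metric='val_loss'):
    """Return the parsed record with the lowest value of `metric`, or None."""
    records = []
    for line in lines:
        if not line.strip():
            continue
        record = parse_log_line(line)
        try:
            record[metric] = float(record[metric])
        except (KeyError, ValueError):
            continue  # skip lines where the metric is missing or logged as 'None'
        records.append(record)
    return min(records, key=lambda r: r[metric]) if records else None


# Illustrative lines in the format produced by LossHistory (values are invented).
sample = [
    "0,Optimizer,Adam,Batch_size,32,accuracy,0.41,val_loss,1.52",
    "1,Optimizer,Adam,Batch_size,32,accuracy,0.55,val_loss,1.20",
    "0,Optimizer,SGD,Batch_size,32,accuracy,0.30,val_loss,1.58",
]
best = best_record(sample)
print(best['Optimizer'], best['epoch'], best['val_loss'])
```

In practice the lines would come from open('somefile.txt'). Note that the callback as written logs only the optimizer and batch size, so a parameter such as lstm_nodes would have to be added to the log before it could be recovered this way. With n_jobs=-1, several worker processes may also append to the same file at once.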

Adapt the parameters captured in the LossHistory class, and the format of the file, to your needs (this is just an example). Here is more documentation on callbacks in keras.
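
If a checkpoint per fold and per parameter combination is still wanted, one possible workaround (not part of the answer above, and only a sketch) is to loop over combinations and folds manually instead of using GridSearchCV, giving each run its own ModelCheckpoint filepath. The helper checkpoint_paths and the filename pattern below are hypothetical. The {epoch:02d} placeholder is deliberately left unformatted so that ModelCheckpoint can fill it in.

```python
from itertools import product


def checkpoint_paths(param_grid, n_folds):
    """Yield (params, fold, filepath) for every parameter combination and fold."""
    keys = sorted(param_grid)
    for values in product(*(param_grid[k] for k in keys)):
        params = dict(zip(keys, values))
        tag = "-".join(f"{k}={params[k]}" for k in keys)
        for fold in range(n_folds):
            # Doubled braces keep '{epoch:02d}' literal for ModelCheckpoint.
            yield params, fold, f"weights-{tag}-fold{fold}-{{epoch:02d}}.hdf5"


param_grid = {'optimizer': ['Adam', 'SGD'], 'lstm_nodes': [12, 18, 24]}
paths = list(checkpoint_paths(param_grid, n_folds=5))
for params, fold, path in paths[:2]:
    print(params, fold, path)
```

Each yielded path could then be passed to a fresh ModelCheckpoint inside a manual loop over the train/validation splits, at the cost of giving up the bookkeeping that GridSearchCV does.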