无论如何,在KerasRegressor包装器中使用fit_generator()方法吗?

时间:2019-11-30 16:48:34

标签: keras scikit-learn lstm gridsearchcv

我正在尝试使用GridSearchCV调整LSTM模型的超参数,但使用了来自keras.preprocessing.sequence的TimeseriesGenerator。如何修改KerasRegressor包装器以容纳fit_generator()而不是fit()方法?

def create_model(layer1=50): 
    lstm_model = Sequential()
    lstm_model.add(LSTM(layer1, input_shape=(10,11)))
    lstm_model.add(Dense(11, activation='tanh'))
    lstm_model.compile(loss='mean_squared_error', optimizer='rmsprop')
    return lstm_model

model = KerasRegressor(build_fn=create_model, epochs=5, batch_size=10, verbose=0)

layer1 = [40,50]
param_grid = dict(layer1=layer1)

grid = GridSearchCV(estimator=model, param_grid=param_grid, n_jobs=-1, cv=3)
grid_result = grid.fit(X, y)

但是如何将KerasRegressor与类似

的生成器一起使用
from keras.preprocessing.sequence import TimeseriesGenerator
train_data_gen = TimeseriesGenerator(X, y,length=10,batch_size=100)

0 个答案:

没有答案