尽早停止后如何获得最佳模型而不将其保存到文件中?

时间:2019-06-30 13:52:30

标签: python keras

如果提前停止训练,则最佳模型将保存到文件best_model.h5中。但是从文件加载模型需要相对较长的时间。有没有办法以另一种方式获得最佳模型?

例如,通过在内存中创建文件并从内存中读取文件。或通过将每个时期的每个模型放入列表中,然后使用EarlyStopping.stopped_epoch从列表中获取最佳模型来访问相应的列表项。

import numpy as np
import pandas as pd
from keras.models import Sequential
from keras.layers import Dense
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.models import load_model
from sklearn.model_selection import train_test_split

df = pd.DataFrame(np.random.randint(0,100,size=(1000, 3))/100, columns=['x_1', 'x_2','y'])
x_train, x_test, y_train, y_test = train_test_split(df[['x_1', 'x_2']], df[['y']], test_size=0.2, random_state=0)
x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.2, random_state=0)

callbacks = [EarlyStopping(monitor='val_loss', patience=2),
             ModelCheckpoint(filepath='best_model.h5', monitor='val_loss', save_best_only=True)]

model = Sequential()
model.add(Dense(units=1, activation='linear', input_dim=2))
model.add(Dense(units=1))
model.compile(loss='mean_squared_error',optimizer='adam',metrics=['mean_absolute_error', 'mean_squared_error'])
model.fit(x_train, y_train, epochs=100, batch_size=32, callbacks=callbacks, validation_data=(x_val, y_val))

model = load_model('best_model.h5')

print(model.evaluate(x_test, y_test, batch_size=32))

1 个答案:

答案 0 :(得分:0)

您需要在某个时候存储模型的当前状态,因为进一步的训练会改变该状态。

如果您不打算保存(使用ModelCheckpoint的最简单选项),那么您需要一个自定义的回调函数来执行stored_weights = model.get_weights()

选项1

您可以尝试在save_weights_only=True中使用ModelCheckpoint,并在以后的模型中使用model.load_weights(path)。这不会创建新模型,只会加载权重。

选项2

即使您仍然认为它太慢,也可以创建自己的回调并使用RAM存储权重:

from keras.callbacks import LambdaCallback

bestLoss = 1000000000000000000
bestWeights = None

def storeWeights(e, logs):
    if logs['val_loss'] < bestLoss:
        bestLoss = logs['val_loss']
        bestWeights = model.get_weights()

callbacks = [EarlyStopping(monitor='val_loss', patience=2), 
             LambdaCallback(on_epoch_end=storeWeights)]

#train here
model.fit...........................
#finished train

model.set_weights(bestWeights)