我正在构建一个keras LSTM模型,并且在初次通过时,我发现它有点过拟合数据,所以我初始化了2个回调-一个用于控制变量学习率的回调,另一个用于提早停止:
def _initialise_callback(self):
# Ensure learning rate decreases with the epoch number
learning_rate = 0.1
decay_rate = learning_rate / self.epochs
momentum = 0.8
self.sgd = SGD(lr=learning_rate, momentum=momentum, decay=decay_rate, nesterov=False)
#Allow model to stop early to prevent overfitting
self.early_stopping = EarlyStopping(monitor='loss', patience=3)
但是由于某种原因,我似乎无法将它们都传递给fit()
方法。我要做的是:
def fit(self):
self.model.fit(self.train_set, epochs=self.epochs, verbose=2, shuffle=False,
callbacks=[self.early_stopping, self.sgd],
use_multiprocessing=False)
并导致以下错误:
File "<ipython-input-1-1532e4234d2a>", line 1, in <module>
runfile('C:/VULCAN_HOME/sampling_bias/bias_LSTM.py', wdir='C:/VULCAN_HOME/sampling_bias')
File "C:\ProgramData\Anaconda3\lib\site-packages\spyder\utils\site\sitecustomize.py", line 710, in runfile
execfile(filename, namespace)
File "C:\ProgramData\Anaconda3\lib\site-packages\spyder\utils\site\sitecustomize.py", line 101, in execfile
exec(compile(f.read(), filename, 'exec'), namespace)
File "C:/VULCAN_HOME/sampling_bias/bias_LSTM.py", line 174, in <module>
predictor.fit()
File "C:/VULCAN_HOME/sampling_bias/bias_LSTM.py", line 164, in fit
use_multiprocessing=False)
File "C:\ProgramData\Anaconda3\lib\site-packages\keras\engine\training.py", line 1147, in fit
initial_epoch=initial_epoch)
File "C:\ProgramData\Anaconda3\lib\site-packages\keras\legacy\interfaces.py", line 91, in wrapper
return func(*args, **kwargs)
File "C:\ProgramData\Anaconda3\lib\site-packages\keras\engine\training.py", line 1732, in fit_generator
initial_epoch=initial_epoch)
File "C:\ProgramData\Anaconda3\lib\site-packages\keras\engine\training_generator.py", line 100, in fit_generator
callbacks.set_model(callback_model)
File "C:\ProgramData\Anaconda3\lib\site-packages\keras\callbacks\callbacks.py", line 68, in set_model
callback.set_model(model)
AttributeError: 'SGD' object has no attribute 'set_model'
另一方面,如果我尝试仅通过sgd
或仅通过early_stopping
,则一切正常。有人知道这里发生了什么吗?
答案 0 :(得分:2)
SGD
优化器应该作为参数传递给compile
方法,如图here所示,而不是作为fit方法的回调参数传递。我已经在下面修改了您的代码:
def fit(self):
self.model.fit(self.train_set, epochs=self.epochs, verbose=2, shuffle=False,
callbacks=[self.early_stopping],
use_multiprocessing=False)
当您编译模型时,请通过优化器
self.model.compile(optimizer=self.sgd, **kwargs)
希望这会有所帮助!