无法将多个回调传递给keras模型

时间:2020-06-14 06:14:44

标签: python tensorflow keras

我正在构建一个keras LSTM模型,并且在初次通过时,我发现它有点过拟合数据,所以我初始化了2个回调-一个用于控制变量学习率的回调,另一个用于提早停止:

    def _initialise_callback(self):

        # Ensure learning rate decreases with the epoch number
        learning_rate = 0.1
        decay_rate = learning_rate / self.epochs
        momentum = 0.8
        self.sgd = SGD(lr=learning_rate, momentum=momentum, decay=decay_rate, nesterov=False)

        #Allow model to stop early to prevent overfitting
        self.early_stopping = EarlyStopping(monitor='loss', patience=3)

但是由于某种原因,我似乎无法将它们都传递给fit()方法。我要做的是:

    def fit(self):
        self.model.fit(self.train_set, epochs=self.epochs, verbose=2, shuffle=False,
                       callbacks=[self.early_stopping, self.sgd],
                       use_multiprocessing=False)

并导致以下错误:


  File "<ipython-input-1-1532e4234d2a>", line 1, in <module>
    runfile('C:/VULCAN_HOME/sampling_bias/bias_LSTM.py', wdir='C:/VULCAN_HOME/sampling_bias')

  File "C:\ProgramData\Anaconda3\lib\site-packages\spyder\utils\site\sitecustomize.py", line 710, in runfile
    execfile(filename, namespace)

  File "C:\ProgramData\Anaconda3\lib\site-packages\spyder\utils\site\sitecustomize.py", line 101, in execfile
    exec(compile(f.read(), filename, 'exec'), namespace)

  File "C:/VULCAN_HOME/sampling_bias/bias_LSTM.py", line 174, in <module>
    predictor.fit()

  File "C:/VULCAN_HOME/sampling_bias/bias_LSTM.py", line 164, in fit
    use_multiprocessing=False)

  File "C:\ProgramData\Anaconda3\lib\site-packages\keras\engine\training.py", line 1147, in fit
    initial_epoch=initial_epoch)

  File "C:\ProgramData\Anaconda3\lib\site-packages\keras\legacy\interfaces.py", line 91, in wrapper
    return func(*args, **kwargs)

  File "C:\ProgramData\Anaconda3\lib\site-packages\keras\engine\training.py", line 1732, in fit_generator
    initial_epoch=initial_epoch)

  File "C:\ProgramData\Anaconda3\lib\site-packages\keras\engine\training_generator.py", line 100, in fit_generator
    callbacks.set_model(callback_model)

  File "C:\ProgramData\Anaconda3\lib\site-packages\keras\callbacks\callbacks.py", line 68, in set_model
    callback.set_model(model)

AttributeError: 'SGD' object has no attribute 'set_model'

另一方面,如果我尝试仅通过sgd或仅通过early_stopping,则一切正常。有人知道这里发生了什么吗?

1 个答案:

答案 0 :(得分:2)

SGD优化器应该作为参数传递给compile方法,如图here所示,而不是作为fit方法的回调参数传递。我已经在下面修改了您的代码:

def fit(self):
        self.model.fit(self.train_set, epochs=self.epochs, verbose=2, shuffle=False,
                       callbacks=[self.early_stopping],
                       use_multiprocessing=False)

当您编译模型时,请通过优化器

self.model.compile(optimizer=self.sgd, **kwargs) 

希望这会有所帮助!