在Keras中多次调用“ fit_generator()”

时间:2019-06-13 04:46:54

标签: python tensorflow keras deep-learning computer-vision

我有一个生成器函数,可以生成(输入,目标)元组,并使用Keras中的fit_generator()方法在其上训练我的模型。

我的数据集分为9个相等部分。我希望使用fit_generator()方法对数据集进行一劳永逸的交叉验证,并保持先前​​训练的学习参数不变。

我的问题是,将在模型上多次调用fit_generator()使其从头开始重新学习先前训练和验证集上的学习参数,还是将这些学习参数保持不变以提高准确性?

经过一番挖掘,我发现Keras中的fit()方法保留了此处Calling "fit" multiple times in Keras处的学习参数,但是我不确定fit_generator()是否也会发生同样的情况可以将其用于数据的交叉验证。

我正在考虑实现交叉验证的伪代码如下:

class DatasetGenerator(Sequence):
    def __init__(validation_id, mode):
        #Some code

    def __getitem__():
        #The generator function

        #Some code

        return (inputs, targets)

for id in range(9):

    train_set = DatasetGenerator(id, 'train') 
    #train_set contains all 8 parts leaving the id part out for validation.

    validation_set = DatasetGenerator(id, 'val')
    #val_set contains the id part.

    history = model.fit_generator(train_set, epochs = 10, steps_per_epoch = 24000, validation_data = val_set, validation_steps = 3000)

print('History Dict:', history.history)
results = model.evaluate_generator(test_set, steps=steps)
print('Test loss, acc:', results)

model会保持学习的参数完整并在for循环的每次迭代中对其进行改进吗?

2 个答案:

答案 0 :(得分:3)

fitfit_generator在这方面的行为相同,再次调用它们将恢复以前训练过的权重。

还请注意,您要执行的操作不是交叉验证,而是进行真正的交叉验证,您需要为每个折叠训练一个模型,并且这些模型是完全独立的,而不是继续前一个折叠的训练。

答案 1 :(得分:1)

据我所知,它将保留以前训练有素的参数。另外,我认为您可以尝试通过修改Sequence的on_epoch_end()方法来完成。可能是这样的:

class DatasetGenerator(Sequence):
    def __init__(self, id, mode):
        self.id = id
        self.mode = mode
        self.current_epoch=0
        #some code

    def __getitem__(self, idx):
        id = self.id
        #Some code
        return (inputs, targets)

    def on_epoch_end():
        self.current_epoch += 1
        if self.current_epoch % 10 == 0:
            self.id += 1