我有一个生成器函数,可以生成(输入,目标)元组,并使用Keras中的fit_generator()
方法在其上训练我的模型。
我的数据集分为9个相等部分。我希望使用fit_generator()
方法对数据集进行一劳永逸的交叉验证,并保持先前训练的学习参数不变。
我的问题是,将在模型上多次调用fit_generator()
使其从头开始重新学习先前训练和验证集上的学习参数,还是将这些学习参数保持不变以提高准确性?
经过一番挖掘,我发现Keras中的fit()
方法保留了此处Calling "fit" multiple times in Keras处的学习参数,但是我不确定fit_generator()
是否也会发生同样的情况可以将其用于数据的交叉验证。
我正在考虑实现交叉验证的伪代码如下:
class DatasetGenerator(Sequence):
def __init__(validation_id, mode):
#Some code
def __getitem__():
#The generator function
#Some code
return (inputs, targets)
for id in range(9):
train_set = DatasetGenerator(id, 'train')
#train_set contains all 8 parts leaving the id part out for validation.
validation_set = DatasetGenerator(id, 'val')
#val_set contains the id part.
history = model.fit_generator(train_set, epochs = 10, steps_per_epoch = 24000, validation_data = val_set, validation_steps = 3000)
print('History Dict:', history.history)
results = model.evaluate_generator(test_set, steps=steps)
print('Test loss, acc:', results)
model
会保持学习的参数完整并在for
循环的每次迭代中对其进行改进吗?
答案 0 :(得分:3)
fit
和fit_generator
在这方面的行为相同,再次调用它们将恢复以前训练过的权重。
还请注意,您要执行的操作不是交叉验证,而是进行真正的交叉验证,您需要为每个折叠训练一个模型,并且这些模型是完全独立的,而不是继续前一个折叠的训练。
答案 1 :(得分:1)
据我所知,它将保留以前训练有素的参数。另外,我认为您可以尝试通过修改Sequence的on_epoch_end()方法来完成。可能是这样的:
class DatasetGenerator(Sequence):
def __init__(self, id, mode):
self.id = id
self.mode = mode
self.current_epoch=0
#some code
def __getitem__(self, idx):
id = self.id
#Some code
return (inputs, targets)
def on_epoch_end():
self.current_epoch += 1
if self.current_epoch % 10 == 0:
self.id += 1