在Keras上循环使用model.fit是否合乎逻辑?

时间:2018-05-21 12:17:38

标签: keras large-data batching

在Keras中执行以下操作以便不会耗尽内存是否合乎逻辑?

for path in ['xaa', 'xab', 'xac', 'xad']:
    x_train, y_train = prepare_data(path)
    model.fit(x_train, y_train, batch_size=50, epochs=20, shuffle=True)

model.save('model')

2 个答案:

答案 0 :(得分:2)

但是,如果每次迭代生成一个批次,则更喜欢fit。这消除了model.fit_generator()带来的一些开销。

您还可以尝试创建生成器并使用def dataGenerator(pathes, batch_size): while True: #generators for keras must be infinite for path in pathes: x_train, y_train = prepare_data(path) totalSamps = x_train.shape[0] batches = totalSamps // batch_size if totalSamps % batch_size > 0: batches+=1 for batch in range(batches): section = slice(batch*batch_size,(batch+1)*batch_size) yield (x_train[section], y_train[section])

gen = dataGenerator(['xaa', 'xab', 'xac', 'xad'], 50)
model.fit_generator(gen,
                    steps_per_epoch = expectedTotalNumberOfYieldsForOneEpoch
                    epochs = epochs)

创建并使用:

HAVING

答案 1 :(得分:1)

我建议在Github上查看这个thread

您确实可以考虑使用model.fit(),但这样可以让训练更加稳定:

for epoch in range(20):
    for path in ['xaa', 'xab', 'xac', 'xad']:
        x_train, y_train = prepare_data(path)
        model.fit(x_train, y_train, batch_size=50, epochs=epoch+1, initial_epoch=epoch, shuffle=True)

通过这种方式,您可以在每个纪元上迭代所有数据,而不是在切换之前在部分数据上迭代20个纪元。

正如线程中所讨论的,另一种解决方案是开发自己的数据生成器并将其与model.fit_generator()一起使用。