如何:fit_generator in keras

时间:2017-10-04 16:52:02

标签: python keras

我对如何在fit_generator中使用keras感到有些困惑。

举例说:

  • 我们有10000个数据点
  • 我们希望运行10个纪元
  • 批量大小为512

我们只使用fit

x, y = load_data()
model.fit(x=x, y=y, batch_size=512, epochs=10)

其中load_data加载所有数据。

现在如何对fit_generator执行相同操作。

我不清楚使用fit_generator时如何处理它。如果我有以下发电机:

def data_generator():
    for x, y in load_data_per_line():
        yield x, y

在上面的生成器中每次yields一个数据点。和

def data_generator_2():
    x_output = []
    y_output = []
    i = 0
    for x, y in load_data_per_line():
        x_output[i] = x
        y_output[i] = y
        i = i + 1
        if i == batch_size:
           yield x_output, y_output
           i = 0
           x_output = []
           y_output = []

在上面的生成器中,每次yields批量大小数据点(在这种情况下为512)。

fit实现相同但使用fit_generator

model.fit_generator(data_generator(), steps_per_epoch=10000 / 512, epochs=10)

model.fit_generator(data_generator_2(), steps_per_epoch=10000 / 512, epochs=10)

或两者都错(fit_generatordata_generator s)?如果它们中的任何一个是正确的,那么是否保证所有数据点都将被处理并且还要按顺序处理?

任何见解都很有用

1 个答案:

答案 0 :(得分:2)

生成器2几乎没问题,但最好还是返回numpy数组:

yield np.asarray(x_output),np.asarray(y_output)

此外,它应该是无限的:

while True: 

    #the code inside to loop infinitely

第一个不会返回批次并且会失败。

您可能在steps_per_epoch中遇到问题,因为10000不是512的倍数。您需要整数步骤。您可以在生成器内检查if i == 10000:并将最小批次作为最后一批传递。

然后你有(10000 //512) + (10000 % 512)个步骤或批次。

将按顺序读取所有批次,但keras会自动对这些批次的内容进行随机播放,请使用suffle=False。如果使用多线程(不是这种情况),则需要创建线程安全生成器或使用keras Sequence