keras中的fit_generator:指定的batch_size在哪里?

时间:2017-05-04 10:21:15

标签: tensorflow keras

您好我不懂keras fit_generator docs。

我希望我的困惑是理性的。

有一个batch_size以及批量培训的概念。使用model_fit(),我指定的batch_size为128。

对我而言,这意味着我的数据集将一次输入128个样本,从而大大减轻了内存。只要我有时间等待,它应该允许训练1亿个样本数据集。毕竟,keras一次只能“处理”128个样本。正确?

但我非常怀疑单独指定batch_size并不能做我想做的事情。仍在使用大量内存。为了我的目标,我需要分别训练128个例子。

所以我猜这是fit_generator的作用。我真的想问为什么batch_size实际上不起作用,顾名思义?

更重要的是,如果需要fit_generator,我在哪里指定batch_size?文档说无限循环。 生成器循环遍历每一行一次。如何一次循环128个样本并记住我上次停止的位置并在下次keras要求下一批次的起始行号时调用它(在第一批完成后将是第129行)。

2 个答案:

答案 0 :(得分:0)

首先,keras batch_size确实运行良好。如果您正在使用GPU,您应该知道该模型对于keras来说可能非常繁重,特别是如果您使用的是复发单元格。如果您正在使用CPU,整个程序将加载到内存中,批量大小不会对内存产生太大影响。如果您使用fit(),则整个数据集可能已加载到内存中,keras会在每个步骤生成批次。预测将要使用的内存量非常困难。

至于fit_generator()方法,你应该构建一个python生成器函数(使用yield而不是return),每一步产生一个批处理。 yield应该处于无限循环中(我们经常使用while true: ...)。

您是否有一些代码来说明您的问题?

答案 1 :(得分:0)

您将需要以某种方式在生成器内部处理批量大小。这是生成随机批次的示例:

import numpy as np
data = np.arange(100)
data_lab = data%2
wholeData = np.array([data, data_lab])
wholeData = wholeData.T

def data_generator(all_data, batch_size = 20):

    while True:        

        idx = np.random.randint(len(all_data), size=batch_size)

        # Assuming the last column contains labels
        batch_x = all_data[idx, :-1]
        batch_y = all_data[idx, -1]

        # Return a tuple of (Xs,Ys) to feed the model
        yield(batch_x, batch_y)

print([x for x in data_generator(wholeData)])