如何使用批处理将Keras ImageDataGenerator拟合为大型数据集

时间:2019-07-18 10:39:07

标签: python keras data-augmentation

我想使用Keras ImageDataGenerator进行数据扩充。 为此,我必须使用我的训练数据作为参数,在实例化的ImageDataGenerator对象上调用.fit()函数,如下所示。

image_datagen = ImageDataGenerator(featurewise_center=True, rotation_range=90)
image_datagen.fit(X_train, augment=True)
train_generator = image_datagen.flow_from_directory('data/images')
model.fit_generator(train_generator, steps_per_epoch=2000, epochs=50)

但是,我的训练数据集太大,无法立即加载到内存中。 因此,我想使用训练数据的子集在多个步骤中适应生成器。

有没有办法做到这一点?

我想到的一个潜在解决方案是使用自定义生成器功能加载一批训练数据,并在一个循环中多次拟合图像生成器。但是,我不确定ImageDataGenerator的fit函数是否可以以这种方式使用,因为它可能会在每种拟合方法上重置。

作为其工作方式的示例:

def custom_train_generator():
    # Code loading training data subsets X_batch
    yield X_batch


image_datagen = ImageDataGenerator(featurewise_center=True, rotation_range=90)
gen = custom_train_generator()

for batch in gen:
    image_datagen.fit(batch, augment=True)

train_generator = image_datagen.flow_from_directory('data/images')
model.fit_generator(train_generator, steps_per_epoch=2000, epochs=50)

1 个答案:

答案 0 :(得分:1)

ImageDataGenerator为您提供了将数据批量加载的可能性;实际上,您可以在 fit_generator 方法中使用与ImageDataGenerator一起使用的参数 batch_size ;不需要(如果需要的话,只是为了好的实践)从头开始编写生成器。

来自Keras官方文档的示例:

datagen = ImageDataGenerator(
    featurewise_center=True,
    featurewise_std_normalization=True,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True)

# compute quantities required for featurewise normalization
# (std, mean, and principal components if ZCA whitening is applied)
datagen.fit(x_train)

# fits the model on batches with real-time data augmentation:
model.fit_generator(datagen.flow(x_train, y_train, batch_size=32),
                    steps_per_epoch=len(x_train) / 32, epochs=epochs)

我建议阅读这篇有关ImageDataGenenerator和Augmentation的出色文章:https://machinelearningmastery.com/how-to-configure-image-data-augmentation-when-training-deep-learning-neural-networks/

您的问题的解决方案在于以下代码行(简单流或flow_from_directory):

# prepare iterator
it = datagen.flow(samples, batch_size=1)