Keras:使用flow_from_directory为训练数据拟合图像增强

时间:2017-10-12 09:07:04

标签: machine-learning keras deep-learning

我想在Keras中使用图像增强。我目前的代码如下:

# define image augmentations
train_datagen = ImageDataGenerator(
featurewise_center=True,
featurewise_std_normalization=True,
zca_whitening=True)

# generate image batches from directory
train_datagen.flow_from_directory(train_dir)

当我使用此模型运行模型时,出现以下错误:

"ImageDataGenerator specifies `featurewise_std_normalization`, but it hasn't been fit on any training data."

但我没有找到有关如何将train_dataget.fit()flow_from_directory一起使用的明确信息。

感谢您的帮助。 马里奥

1 个答案:

答案 0 :(得分:12)

你是对的,docs对此并不是很有启发性......

您需要的实际上是一个分为四步的过程:

  1. 定义数据扩充
  2. 适合扩充
  3. 使用flow_from_directory()
  4. 设置您的生成器
  5. 使用fit_generator()
  6. 训练您的模型

    以下是假设图像分类案例的必要代码:

    # define data augmentation configuration
    train_datagen = ImageDataGenerator(featurewise_center=True,
                                       featurewise_std_normalization=True,
                                       zca_whitening=True)
    
    # fit the data augmentation
    train_datagen.fit(x_train)
    
    # setup generator
    train_generator = train_datagen.flow_from_directory(
            train_data_dir,
            target_size=(img_height, img_width),
            batch_size=batch_size,
            class_mode='categorical')
    
    # train model
    model.fit_generator(
        train_generator,
        steps_per_epoch=nb_train_samples,
        epochs=epochs,
        validation_data=validation_generator, # optional - if used needs to be defined
        validation_steps=nb_validation_samples) 
    

    显然,有几个参数需要定义(train_data_dirnb_train_samples等),但希望你明白这一点。

    如果您还需要使用validation_generator,就像在我的示例中一样,这应该与您的train_generator一样定义。

    更新(评论后)

    第2步需要一些讨论;在这里,x_train是理想情况下应该适合主存储器的实际数据。另外(documentation),这一步是

      

    仅在featurewise_center或featurewise_std_normalization或zca_whitening时才需要。

    然而,在许多现实世界中,所有训练数据都适合记忆的要求显然是不现实的。在这种情况下,如何集中/规范化/白化数据本身就是一个(巨大的)子领域,可以说是存在大数据处理框架(如Spark)的主要原因。

    那么,在这里做什么呢?那么,在这种情况下,下一个合乎逻辑的行动是示例您的数据;事实上,这正是社区所建议的 - 这里是Keras创作者Francois Chollet Working with large datasets like Imagenet

    datagen.fit(X_sample) # let's say X_sample is a small-ish but statistically representative sample of your data
    

    ongoing open discussion关于延长ImageDataGenerator的另一个引用(强调补充):

      

    功能是特征标准化和ZCA所必需的,它只需要一个数组作为参数,不适合目录。 目前,我们需要手动读取图像的一个子集,以便适合目录。一个想法是我们可以更改fit()以接受生成器本身(flow_from_directory),当然,标准化应该在适合期间禁用。

    希望这会有所帮助......