Keras:ImageDataGenerator性能不佳

时间:2019-01-22 07:09:28

标签: python image keras computer-vision data-augmentation

我尝试使用Keras ImageDataGenerator扩充图像数据。我的任务是回归任务,其中输入图像会生成另一个变换后的图像。到目前为止,效果很好。

在这里,我想使用ImageDataGenerator来应用数据扩充。为了以相同的方式变换两个图像,我使用了Keras docs中描述的方法,其中描述了具有相应遮罩的图像变换。我的情况有些不同,因为我的图像已经加载,不需要从目录中提取。 another StackOverlow post中已经描述了此过程。

要验证我的实现,我首先使用了它而未进行扩充,并且使用了ImageDataGenerator而不指定任何参数。根据{{​​3}}中的类引用,这不应更改图像。请参阅以下代码段:

img_val = img[0:split_seperator]
img_train = img[split_seperator:]

target_val = target[0:split_seperator]
target_train = target[split_seperator:]

data_gen_args = dict()

# define data preparation
src_datagen = ImageDataGenerator(**data_gen_args)
target_datagen = ImageDataGenerator(**data_gen_args)

# fit parameters from data
seed = 1
src_datagen.fit(img_train, augment=False, seed=seed)
target_datagen.fit(target_train, augment=False, seed=seed)

training_generator = zip(
    src_datagen.flow(img_train, batch_size=batch_size_training, seed=seed),
    target_datagen.flow(target_train, batch_size=batch_size_training, seed=seed))

_ = model.fit_generator(
    generator=training_generator,
    steps_per_epoch=image_train.shape[0] // batch_size_training,
    epochs=num_epochs, verbose=1,
    validation_data=(img_val, target_val), callbacks=callbacks)

不幸的是,我的实现似乎有一些问题。我没有得到预期的表演。验证损失在某种程度上稳定在某个值附近,并且仅略有下降(请参见下图)。我希望在这里,因为我没有使用任何增强,所以损失与未增强的基线相同。

Keras docs

相比之下,没有ImageDataGenerator的训练看起来像

_ = model.fit(img, target,
              batch_size=batch_size_training,
              epochs=num_epochs, verbose=1,
              validation_split=0.2, callbacks=cb)

我想我在某种程度上与ImageDataGeneratorflowfit函数的用法混为一谈。所以我的问题是:

  • 所应用的功能fitflow之一是否多余并导致此行为?
  • 我有实施上的问题吗?
  • 这种实现通常有意义吗?
  • 设置验证集修订是否有意义,还是应该增加验证集?

更新(2019年1月23日及以后): 到目前为止,我已经尝试过(响应答案):

  • 也为验证数据创建一个生成器
  • 删除应用的拟合函数
  • 在流函数中设置shuffle=True(数据已经被重新整理)

这两种方法都不能使结果更好。

1 个答案:

答案 0 :(得分:1)

最后,我了解您要做什么,这应该可以完成工作。

aug = ImageDataGenerator(**data_gen_args)

# train the network
H = model.fit_generator(aug.flow(img_train, target_train, batch_size=image_train.shape[0]),
    validation_data=(img_val, target_val), steps_per_epoch=image_train.shape[0] // BS,
    epochs=EPOCHS)

让我知道这是否可行。