Keras ImageDataGenerator:如何在图像路径中使用数据增强

时间:2020-06-22 16:16:54

标签: python tensorflow keras deep-learning data-augmentation

我正在研究CNN模型,我想使用一些数据扩充功能,但是会出现两个问题:

  1. 我的标签是图像(我的模型是某种自动编码器,但是预期的输出图像与输入的图像不同),因此无法使用诸如ImageDataGenerator.flow_from_directory()之类的功能。我当时在想ImageDataGenerator.flow(train_list, y = labels_list),但是第二个问题是我来的:
  2. 我的输入和标签数据集都非常大,我更喜欢使用图像路径flow()函数未正确处理),而不是将所有数据集加载到单个阵列并使我的RAM爆炸。

如何正确处理这两个问题?对于我发现的情况,可能有两种解决方案:

  1. 创建我自己的生成器:我听说过__getitem__类中的Keras Sequence函数,但这会影响ImageDataGenerator类吗? / li>
  2. 使用TF DATA或TFRecords ,但它们似乎很难使用,并且仍然需要实施数据增强。

有没有最简单的方法来克服这个简单的问题?一个简单的技巧就是强制ImageDataGenerator.flow()使用nparray的图像路径,而不是nparray的图像,但是我担心修改Keras / tensorflow文件会产生意想不到的后果(因为一些函数在其他类中被调用,局部更改可能很快导致我所有笔记本库中的全局更改。

1 个答案:

答案 0 :(得分:2)

好吧,感谢this article,我终于找到了解决这些问题的方法。我的错误是尽管ImageDataGenerator缺乏灵活性,但我仍然使用它,因此解决方案很简单:使用另一个数据增强工具。

我们可以按以下方式恢复作者的方法:

  1. 首先,创建个性化批处理生成器作为Keras Sequence类的子类(这意味着要实现__getitem__函数,该函数根据图像的各自路径加载图像) 。
  2. 使用数据扩充albumentations。它的优点是提供更多的转换功能,例如ImgaugImageDataGenerator,同时速度更快。此外,this website允许您使用自己的图像测试其某些增强方法!有关详尽列表,请参见this one

该库的缺点是,由于它相对较新,因此在网上找不到很少的文档,而且我花了几个小时来尝试解决遇到的问题。

实际上,当我尝试可视化一些增强函数时,结果是全黑图像(奇怪的事实:只有当我使用RandomGamma或{{1 }}。使用RandomBrightnessContrastHorizontalFlip之类的转换函数,它将正常工作。

在整整半天的尝试中发现了什么错误之后,我最终想出了这个解决方案,如果您尝试使用该库,可能会为您提供帮助:必须完成图像的加载OpenCV (我正在使用ShiftScaleRotate中的load_imgimg_to_array函数进行加载和处理)。如果有人能解释为什么它不起作用,我将很高兴听到它。

无论如何,这是我显示增强图像的最终代码:

tf.keras.preprocessing.image

编辑:

我在!pip install -U git+https://github.com/albu/albumentations > /dev/null && echo "All libraries are successfully installed!" from albumentations import Compose, HorizontalFlip, RandomBrightnessContrast, ToFloat, RGBShift import cv2 import matplotlib.pyplot as plt import numpy as np from google.colab.patches import cv2_imshow # I work on a Google Colab, thus I cannot use cv2.imshow() augmentation = Compose([HorizontalFlip(p = 0.5), RandomBrightnessContrast(p = 1), ToFloat(max_value = 255) # Normalize the pixels values into the [0,1] interval # Feel free to add more ! ]) img = cv2.imread('Your_path_here.jpg') img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # cv2.imread() loads the images in BGR format, thus you have to convert it to RGB before applying any transformation function. img = augmentation(image = img)['image'] # Apply the augmentation functions to the image. plt.figure(figsize=(7, 7)) plt.imshow((img*255).astype(np.uint8)) # Put the pixels values back to [0,255]. Replace by plt.imshow(img) if the ToFloat function is not used. plt.show() ''' If you want to display using cv2_imshow(), simply replace the last three lines by : img = cv2.normalize(img, None, 255,0, cv2.NORM_MINMAX, cv2.CV_8UC1) # if the ToFloat argument is set up inside Compose(), you have to put the pixels values back to [0,255] before plotting them with cv2_imshow(). I couldn't try with cv2.imshow(), but according to the documentation it seems this line would be useless with this displaying function. cv2_imshow(img) I don't recommend it though, because cv2_imshow() plot the images in BGR format, thus some augmentation methods such as RGBShift will not work properly. ''' 库中遇到了几个问题(我在Github上的this question中进行了介绍,但目前我仍然没有任何答案),因此我最好建议使用albumentations用于数据扩充:它的工作原理很好,并且几乎与Imgaug一样容易使用,即使可用的转换功能少了一点。