How to apply Keras image augmentation with a custom data generator?

Date: 2020-07-30 18:02:45

Tags: python tensorflow keras tensorflow2.0

I'm using a Keras custom generator, and I want to apply image augmentation to the data returned by my custom data generator.

These are the augmentation settings I want:

ImageDataGenerator(
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest')
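To see what these settings do to one image, ImageDataGenerator.random_transform applies a single random draw of them to a 3D (height, width, channels) array. This is a minimal sketch, assuming TensorFlow 2.x where tf.keras.preprocessing.image.ImageDataGenerator is still available (it is deprecated in newer releases) and SciPy is installed (the affine transforms rely on it); the random 64x64 image is stand-in data, not a real sample.

```python
import numpy as np
import tensorflow as tf

# Same augmentation settings as in the question.
augmentor = tf.keras.preprocessing.image.ImageDataGenerator(
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest')

# Stand-in for one loaded sample (hypothetical data).
rng = np.random.default_rng(0)
image = rng.random((64, 64, 3)).astype("float32")

# random_transform takes a single 3D image and returns a randomly
# transformed copy with the same shape; seed makes the draw repeatable.
augmented = augmentor.random_transform(image, seed=42)
print(augmented.shape)  # (64, 64, 3)
```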

This is my Keras custom generator:

def __data_generation(self, list_IDs_temp):
    'Generates data containing batch_size samples'  # X : (n_samples, *dim, n_channels)
    # Initialization
    X = np.empty((self.batch_size, *self.dim, self.n_channels))
    y = np.empty((self.batch_size), dtype=int)

    # Generate data
    for i, ID in enumerate(list_IDs_temp):
        # Store sample: load_img returns a PIL image, so resize it to self.dim
        # and convert it to an array that fits into X
        img = tfk.preprocessing.image.load_img(self.list_IDs[ID], target_size=self.dim)
        X[i,] = tfk.preprocessing.image.img_to_array(img)

        # Store class
        y[i] = self.labels[ID]

    return X, tfk.utils.to_categorical(y, num_classes=self.n_classes)
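A note on the loading step: load_img returns a PIL image rather than a NumPy array, and it only resizes when target_size is given, so converting with img_to_array and passing target_size keeps each sample matching the preallocated X. A minimal sketch, assuming TensorFlow 2.x with Pillow installed; the temporary PNG and dim = (64, 64) are hypothetical stand-ins for self.list_IDs[ID] and self.dim.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

image_api = tf.keras.preprocessing.image

# Write a small dummy 100x80 RGB image so the example is self-contained.
tmp_dir = tempfile.mkdtemp()
path = os.path.join(tmp_dir, "sample.png")
image_api.save_img(path, np.random.randint(0, 256, (100, 80, 3)).astype("uint8"))

dim = (64, 64)  # target (height, width), playing the role of self.dim

# load_img resizes to target_size and returns a PIL image;
# img_to_array turns it into a float32 array of shape (height, width, channels).
pil_image = image_api.load_img(path, target_size=dim)
x = image_api.img_to_array(pil_image)
print(x.shape)  # (64, 64, 3)
```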

1 answer:

Answer 0 (score: 1)

I haven't tried this, but I think you can use the flow method of an ImageDataGenerator instance. For example, your custom class could look like this:

import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

tfk = tf.keras  # alias used in the question's code


class CustomDataGenerator(tf.keras.utils.Sequence):

    def __init__(self, batch_size=32):
        self.batch_size = batch_size
        self.augmentor = ImageDataGenerator(
            rotation_range=40,
            width_shift_range=0.2,
            height_shift_range=0.2,
            shear_range=0.2,
            zoom_range=0.2,
            horizontal_flip=True,
            fill_mode='nearest'
        )

    ...

    def __data_generation(self, list_IDs_temp):
        'Generates data containing batch_size samples'  # X : (n_samples, *dim, n_channels)
        # Initialization
        X = np.empty((self.batch_size, *self.dim, self.n_channels))
        y = np.empty((self.batch_size), dtype=int)

        # Generate data
        for i, ID in enumerate(list_IDs_temp):
            # Store sample (load_img returns a PIL image, so resize and convert it)
            img = tfk.preprocessing.image.load_img(self.list_IDs[ID], target_size=self.dim)
            X[i,] = tfk.preprocessing.image.img_to_array(img)

            # Store class
            y[i] = self.labels[ID]

        # Do not shuffle here: your custom class already shuffles beforehand.
        # You only want the transformations applied and, above all, you want
        # the images to stay in sync with the labels.
        X_gen = self.augmentor.flow(X, batch_size=self.batch_size, shuffle=False)

        return next(X_gen), tfk.utils.to_categorical(y, num_classes=self.n_classes)
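Building a new flow iterator for every batch works, but a variant (not part of the original answer) is to call random_transform on each image individually, which never reorders the batch, so X[i] always matches y[i]. This is a hypothetical in-memory sketch: AugmentedSequence, images and labels are illustrative names, and it assumes TensorFlow 2.x with ImageDataGenerator available plus SciPy installed.

```python
import numpy as np
import tensorflow as tf


class AugmentedSequence(tf.keras.utils.Sequence):
    """Hypothetical generator that augments each image with random_transform."""

    def __init__(self, images, labels, n_classes, batch_size=32, shuffle=True):
        super().__init__()
        self.images = images          # (n_samples, height, width, channels)
        self.labels = labels          # (n_samples,) integer class ids
        self.n_classes = n_classes
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.augmentor = tf.keras.preprocessing.image.ImageDataGenerator(
            rotation_range=40, width_shift_range=0.2, height_shift_range=0.2,
            shear_range=0.2, zoom_range=0.2, horizontal_flip=True,
            fill_mode='nearest')
        self.indexes = np.arange(len(self.images))
        self.on_epoch_end()

    def __len__(self):
        # Number of batches per epoch; the last batch may be smaller.
        return int(np.ceil(len(self.images) / self.batch_size))

    def __getitem__(self, index):
        batch_ids = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        # Each image is transformed on its own, in the order of batch_ids,
        # so images and labels stay aligned.
        X = np.stack([self.augmentor.random_transform(self.images[i]) for i in batch_ids])
        y = tf.keras.utils.to_categorical(self.labels[batch_ids], num_classes=self.n_classes)
        return X, y

    def on_epoch_end(self):
        # Shuffle sample order once per epoch, separately from augmentation.
        if self.shuffle:
            np.random.shuffle(self.indexes)


# Usage with random stand-in data: 10 images, 3 classes, batches of 4.
images = np.random.rand(10, 32, 32, 3).astype("float32")
labels = np.arange(10) % 3
seq = AugmentedSequence(images, labels, n_classes=3, batch_size=4, shuffle=False)
X, y = seq[0]
print(len(seq), X.shape, y.shape)  # 3 (4, 32, 32, 3) (4, 3)
```

Applying the transform per image costs one Python call per sample, but it avoids re-creating an iterator in every batch and makes the label alignment explicit.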