我正在使用 keras 的 ImageDataGenerator 来扩充我的数据,但我还想在增强后的图像上应用额外的自定义变换。我知道 ImageDataGenerator 支持通过 preprocessing_function 来实现这一点;问题在于,我的自定义变换需要同时以图像和真实标签(ground truth)作为输入,而 preprocessing_function 只接受单张图像作为输入。
我已经用一种非常繁琐的方式实现了这一点(见下面的代码),我的问题是:是否有更好的做法?作为额外的变换,我对蒙版(即真实标签)做了阈值处理,并应用了一个增强函数——该函数除了图像和蒙版之外,还需要一些额外参数作为输入。
# Build augmented (image, mask) arrays and wrap them in tf.data pipelines.
# Image and mask generators share a flow seed so both receive identical
# random spatial transforms; masks are re-thresholded afterwards because
# ImageDataGenerator interpolates and breaks their binary values.

N_AUGMENTED_BATCHES = 1000  # generator batches to draw per split
MASK_THRESHOLD = 0.5        # re-binarization threshold for interpolated masks
BATCH_SIZE = 32
SHUFFLE_BUFFER = 1000


def _augmented_arrays(x, y, flow_seed, n_batches=N_AUGMENTED_BATCHES):
    """Draw ``n_batches`` augmented batches of images and binary masks.

    Two ImageDataGenerators are given the same ``flow_seed`` so the image
    and its mask undergo the same random spatial transforms.  Returns a
    ``(images, masks)`` pair of concatenated float32 arrays, with masks
    re-thresholded to {0, 1}, and the custom ``augment`` transform (which
    needs both image and mask) already applied to the images in place.
    """
    image_gen = kp.image.ImageDataGenerator(**data_gen_args_).flow(x, seed=flow_seed)
    mask_gen = kp.image.ImageDataGenerator(**data_gen_args_).flow(y, seed=flow_seed)
    images = np.concatenate([next(image_gen) for _ in range(n_batches)])
    # Because Keras data augmentation interpolates the data, a threshold
    # must be taken to make the masks binary again.
    masks = np.concatenate([
        np.where(next(mask_gen) > MASK_THRESHOLD, 1, 0).astype('float32')
        for _ in range(n_batches)
    ])
    # Custom augmentation that requires both image and ground truth.
    for i in range(images.shape[0]):
        images[i] = augment(
            images[i],
            masks[i],
            brightness_range=data_gen_args['brightness_range'],
            noise_var_range=data_gen_args['noise_var_range'],
            bias_var_range=data_gen_args['bias_var_range'],
        )
    return images, masks


def _to_dataset(images, masks):
    """Wrap paired arrays in a repeated, shuffled, batched tf.data.Dataset."""
    # from_tensor_slices on a tuple is equivalent to zipping two
    # per-array datasets, with one call instead of three.
    ds = tf.data.Dataset.from_tensor_slices((images, masks))
    return ds.repeat().shuffle(SHUFFLE_BUFFER).batch(BATCH_SIZE)


# Validation uses seed + 1 so its augmentation stream differs from training.
imgs, masks = _augmented_arrays(x_train, y_train, seed)
imgs_val, masks_val = _augmented_arrays(x_val, y_val, seed + 1)
train_dataset = _to_dataset(imgs, masks)
validation_set = _to_dataset(imgs_val, masks_val)