Applying a preprocessing function to the keras ImageDataGenerator

Posted: 2020-04-15 16:53:48

Tags: image-processing keras deep-learning tensorflow2.0 image-segmentation

I am using the keras ImageDataGenerator to augment my data, but I also want to apply additional custom transformations to the augmented images. I know that ImageDataGenerator accepts a preprocessing_function for exactly this purpose; the problem is that my extra transformations need both the image and the ground truth as input, whereas preprocessing_function receives only a single image.
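
To illustrate the constraint, here is a minimal sketch of how preprocessing_function is normally hooked up (the function below is only a placeholder, not my actual transform): it is called once per image, after the built-in augmentations, so the corresponding mask is simply not available inside it.

    from tensorflow.keras.preprocessing.image import ImageDataGenerator

    def my_preprocessing(image):
        # preprocessing_function gets a single image (rank-3 numpy array) after
        # resizing/augmentation and must return an array of the same shape;
        # there is no way to pass the matching mask into this callback
        return image / 255.0

    datagen = ImageDataGenerator(preprocessing_function=my_preprocessing)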

I have implemented this in a rather cumbersome way (code below), and my question is whether there is a better way to do it. As additional transformations I threshold the masks (the ground truth), and I apply an augmentation function that takes the image and the mask, together with some parameters, as input.

    import numpy as np
    import tensorflow as tf
    from tensorflow.keras import preprocessing as kp   # assuming kp is the keras preprocessing module

    # separate ImageDataGenerators for images and masks, for training and validation
    image_datagen = kp.image.ImageDataGenerator(**data_gen_args_)
    mask_datagen = kp.image.ImageDataGenerator(**data_gen_args_)
    image_val_datagen = kp.image.ImageDataGenerator(**data_gen_args_)
    mask_val_datagen = kp.image.ImageDataGenerator(**data_gen_args_)

    # matching seeds keep the image and mask augmentations in sync
    image_generator = image_datagen.flow(x_train, seed=seed)
    mask_generator = mask_datagen.flow(y_train, seed=seed)
    image_val_generator = image_val_datagen.flow(x_val, seed=seed + 1)
    mask_val_generator = mask_val_datagen.flow(y_val, seed=seed + 1)

    # draw 1000 augmented batches up front; because the augmentation interpolates
    # the masks, they are thresholded to make them binary again
    imgs = [next(image_generator) for _ in range(1000)]
    masks = [np.where(next(mask_generator) > 0.5, 1, 0).astype('float32') for _ in range(1000)]
    imgs_val = [next(image_val_generator) for _ in range(1000)]
    masks_val = [np.where(next(mask_val_generator) > 0.5, 1, 0).astype('float32') for _ in range(1000)]

    imgs = np.concatenate(imgs)
    masks = np.concatenate(masks)
    imgs_val = np.concatenate(imgs_val)
    masks_val = np.concatenate(masks_val)

    # custom augmentation that needs both the image and the mask
    for i in range(imgs.shape[0]):
        imgs[i] = augment(imgs[i], masks[i],
                          brightness_range=data_gen_args['brightness_range'],
                          noise_var_range=data_gen_args['noise_var_range'],
                          bias_var_range=data_gen_args['bias_var_range'])
    for i in range(imgs_val.shape[0]):
        imgs_val[i] = augment(imgs_val[i], masks_val[i],
                              brightness_range=data_gen_args['brightness_range'],
                              noise_var_range=data_gen_args['noise_var_range'],
                              bias_var_range=data_gen_args['bias_var_range'])

    train_dataset = tf.data.Dataset.zip((tf.data.Dataset.from_tensor_slices(imgs),
                                         tf.data.Dataset.from_tensor_slices(masks)))
    train_dataset = train_dataset.repeat().shuffle(1000).batch(32)
    validation_set = tf.data.Dataset.zip((tf.data.Dataset.from_tensor_slices(imgs_val),
                                          tf.data.Dataset.from_tensor_slices(masks_val)))
    validation_set = validation_set.repeat().shuffle(1000).batch(32)
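
For completeness, this is roughly how I feed the resulting datasets to training. A minimal sketch, assuming a compiled model named model; the epoch count is a placeholder, and explicit step counts are needed because both datasets repeat() indefinitely.

    # with infinitely repeating datasets, the steps per epoch must be given explicitly
    steps_per_epoch = imgs.shape[0] // 32
    validation_steps = imgs_val.shape[0] // 32
    model.fit(train_dataset,
              epochs=10,                # placeholder value
              steps_per_epoch=steps_per_epoch,
              validation_data=validation_set,
              validation_steps=validation_steps)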

0 answers:

No answers