keras ImageDataGenerator内插二进制掩码

时间:2020-03-04 08:21:53

标签: python keras tensorflow2.0 data-augmentation

我正在训练一个神经网络,以预测鼠标大脑图像上的二进制掩码。为此,我使用来自keras的ImageDataGenerator扩展了数据。

但是我已经意识到,当应用空间变换时,数据生成器正在对数据进行插值。

这对图像很好,但是我当然不希望我的遮罩包含非二进制值。

在应用变换时是否可以选择类似最近邻插值的方法?我在keras文档中没有找到这样的选项。

To the left is the original binary mask, to the right is the augmented, interpolated mask

(左边是原始二进制掩码,右边是增强的内插掩码)

图片代码:

data_gen_args = dict(rotation_range=90,
                     width_shift_range=30,
                     height_shift_range=30,
                     shear_range=5,
                     zoom_range=0.3,
                     horizontal_flip=True,
                     vertical_flip=True,
                     fill_mode='nearest')
image_datagen = kp.image.ImageDataGenerator(**data_gen_args)
image_generator = image_datagen.flow(image, seed=1)
plt.figure()
plt.subplot(1,2,1)
plt.imshow(np.squeeze(image))
plt.axis('off')
plt.subplot(1,2,2)
plt.imshow(np.squeeze(image_generator.next()[0]))
plt.axis('off')
plt.savefig('vis/keras_example')

1 个答案:

答案 0 :(得分:1)

我自己的二进制图像数据有相同的问题。有多种方法可以解决此问题。

简单答案:我通过将ImageDataGenerator的结果手动转换为二进制来解决了这个问题。如果要手动遍历生成器(使用“ next()”方法或“ for”循环),则只需使用numpy的“ where”方法即可将非二进制值转换为二进制值:

import numpy as np

batch = image_generator.next()
binary_images = np.where(batch>0, 1, 0)  ## or batch>0.5 or any other thresholds

在ImageDataGenerator中使用preprocessing_function自变量

另一种更好的方法是在preprocessing_function中使用ImageDataGenerator参数。如documentation中所述,可以指定一个自定义的预处理函数,该函数将在数据扩充过程之后执行,因此您可以在data_gen_args中指定此函数,如下所示:

from keras.preprocessing.image import ImageDataGenerator

data_gen_args = dict(rotation_range=90,
                     width_shift_range=30,
                     height_shift_range=30,
                     shear_range=5,
                     zoom_range=0.3,
                     horizontal_flip=True,
                     vertical_flip=True,
                     fill_mode='nearest',
                     preprocessing_function = lambda x: np.where(x>0, 1, 0).astype(x.dtype))

注意:根据我的经验,preprocessing_function是在rescale之前执行的,也可以将ImageDataGenerator指定为{ {1}}。这不是您的情况,但是如果您需要指定该参数,请记住这一点。

创建自定义生成器

另一种解决方案是编写自定义数据生成器,并在其中修改ImageDataGenerator的输出。然后,使用此新生成器来填充data_gen_args。像这样:

model.fit()

上面的数据生成器也是一个简单的数据生成器。如果需要,您可以自定义它并添加标签(例如this)或多峰数据等。