Keras:使用带有FCNN的flow_from_directory

时间:2018-01-22 08:23:07

标签: python keras unet

我成功地训练了一个UNET NN和Keras。现在我想尝试在我的图像中应用一些数据增强来达到更好的性能。为此,我使用ImageDataGenerator然后flow_from_directory仅将批次加载到内存中(我尝试了但没有内存错误)。代码示例是:

training_images = np.array(training_images) 
training_masks = np.array(training_masks)[:, :, :, 0].reshape(len(training_masks), 400, 400, 1)

# generators for data augmentation -------
seed = 1
generator_x = ImageDataGenerator(
    featurewise_center=True,
    featurewise_std_normalization=True,
    rotation_range=180,
    horizontal_flip=True,
    fill_mode='reflect')

generator_y = ImageDataGenerator(
    featurewise_center=False,
    featurewise_std_normalization=False,
    rotation_range=180,
    horizontal_flip=True,
    fill_mode='reflect')

generator_x.fit(training_images, augment=True, seed=seed)
generator_y.fit(training_masks, augment=True, seed=seed)


image_generator = generator_x.flow_from_directory(
    'data',
    target_size=(400, 400),
    class_mode=None,
    seed=seed)

mask_generator = generator_y.flow_from_directory(
    'masks',
    target_size=(400, 400),
    class_mode=None,
    seed=seed)

train_generator = zip(image_generator, mask_generator)
model = unet(img_rows, img_cols)
model.fit_generator(train_generator, steps_per_epoch=int(len(training_images)/4), epochs=1)

然而,当我运行代码时,我收到以下错误(我正在使用Tensorflow后端):

InvalidArgumentError (see above for traceback): Incompatible shapes: [14400000] vs. [4800000]
     [[Node: loss/out_loss/mul = Mul[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"](loss/out_loss/Reshape, loss/out_loss/Reshape_1)]]

在错误中它抱怨不兼容的形状14400000(400x400x9)与4800000(400x400x3)。我在这里使用自定义损失函数(如果你看错误它说的是关于损失的东西)那就是Dice系数,定义如下:

y_true_f = K.flatten(y_true)
y_pred_f = K.flatten(y_pred)
intersection = K.sum(y_true_f * y_pred_f)
return (2. * intersection + 1.) / (K.sum(y_true_f) + K.sum(y_pred_f) + 1.)

这里我使用(400,400,3)带有掩模的图像,用于1种形状(400,400,1)。我的UNET输入定义为Input((img_rows, img_cols, 3)),输出为Conv2D(1, (1, 1), activation='sigmoid', name='out')(conv9)(但在没有数据增强的情况下进行训练时工作正常)。

1 个答案:

答案 0 :(得分:1)

发生错误是因为您正在以RGB颜色模式读取遮罩。

color_mode中的默认flow_from_directory'rgb'。因此,如果不指定color_mode,则会将掩码加载到(batch_size, 400, 400, 3)数组中。这就是为什么y_true_f比错误消息中的y_pred_f大3倍。

要以灰度读取遮罩,请使用color_mode='grayscale'

mask_generator = generator_y.flow_from_directory(
    'masks',
    target_size=(400, 400),
    class_mode=None,
    color_mode='grayscale',
    seed=seed)