我正在尝试使用来自Keras的ImageDataGenerator和.flow(),以便为CNN分段任务实时增强图像数据。我对此类编程的基本理解提前道歉,因为我对它很陌生。我有一个功能如下,旨在增加图像数据和相应的数据掩码:
def image_augmentation(imgs, masks):
# create two instances with the same arguments
# create dictionary with the input augmentation values
data_gen_args = dict(featurewise_center=False,
featurewise_std_normalization=False,
rotation_range=90.,
width_shift_range=0.1,
height_shift_range=0.1,
zoom_range=0.2,
horizontal_flip=True,
vertical_flip = True)
## use this method with both images and masks
image_datagen = ImageDataGenerator(**data_gen_args)
mask_datagen = ImageDataGenerator(**data_gen_args)
# Provide the same seed and keyword arguments to the fit and flow methods
seed = 1
## fit the augmentation model to the images and masks with the same seed
image_datagen.fit(imgs, augment=True, seed=seed)
mask_datagen.fit(masks, augment=True, seed=seed)
## set the parameters for the data to come from (images)
image_generator = image_datagen.flow(
imgs,
batch_size=32,
shuffle=True,
seed=seed)
## set the parameters for the data to come from (masks)
mask_generator = mask_datagen.flow(
masks,
batch_size=32,
shuffle=True,
seed=seed)
# combine generators into one which yields image and masks
train_generator = zip(image_generator, mask_generator)
## return the train generator for input in the CNN
return train_generator
根据我的理解,火车发电机应该由32个增强图像和相应的掩模组成。但是,当我跑:
train_generator = image_augmentation(imgs_train, imgs_spine_mask_train)
## run the fit generator CNN
model.fit_generator(train_generator,steps_per_epoch=100,epochs=5000, callbacks=[model_checkpoint])
image_augmentation只会永远运行..没有错误
imgs_train,imgs_spine_mask_train都是4d数组
有谁能让我知道我在这里做错了什么?值得注意的是,我完全是stackoverflow的新手,所以如果我可以更改格式等,请告诉我。
干杯,