我是Keras和TF的新手,但是我有一个独特的问题。我正在训练一个分段网络,该网络将从单独的目录中读取我的训练数据和掩码。我现在正在训练RGBA图像,但是下一步,我需要堆叠其中的一些图像,这将导致Pillow无法处理的10多个通道。另外,我的数据太大而无法容纳在内存中,因此我将需要flow_from_directory功能。
我将以下内容用于RGBA训练样本和RGB蒙版。
def trainGenerator(batch_size,train_path,image_folder,mask_folder,aug_dict,image_color_mode = "rgba",
mask_color_mode = "grayscale",image_save_prefix = "image",mask_save_prefix = "mask",
flag_multi_class = False,num_class = 2,save_to_dir = None,target_size = (256,256),seed = 1):
image_datagen = ImageDataGenerator(**aug_dict)
mask_datagen = ImageDataGenerator(**aug_dict)
image_generator = image_datagen.flow_from_directory(
train_path,
classes = [image_folder],
class_mode = None,
color_mode = image_color_mode,
target_size = target_size,
batch_size = batch_size,
save_to_dir = save_to_dir,
save_prefix = image_save_prefix,
seed = seed)
mask_generator = mask_datagen.flow_from_directory(
train_path,
classes = [mask_folder],
class_mode = None,
color_mode = mask_color_mode,
target_size = target_size,
batch_size = batch_size,
save_to_dir = save_to_dir,
save_prefix = mask_save_prefix,
seed = seed)
train_generator = zip(image_generator, mask_generator)
for (img,mask) in train_generator:
img,mask = adjustData(img,mask,flag_multi_class,num_class)
yield (img,mask)
adjustData函数仅对蒙版进行一键式编码-因此不相关。
但是现在我正在尝试以下方法:
def npyTrainGenerator(image_path,mask_path,flag_multi_class = False,num_class = 2,image_prefix = "image",mask_prefix = "mask",image_as_gray = False,mask_as_gray = True):
image_name_arr = glob.glob(os.path.join(image_path,"%s*.npy"%image_prefix))
image_arr = []
mask_arr = []
for index,item in enumerate(image_name_arr):
img = np.load(item)
mask = io.imread(item.replace(image_path,mask_path).replace(image_prefix,mask_prefix),as_gray = mask_as_gray)
img,mask = adjustData(img,mask,flag_multi_class,num_class)
image_arr.append(img)
mask_arr.append(mask)
image_arr = np.array(image_arr)
mask_arr = np.array(mask_arr)
return image_arr,mask_arr
问题有两件事:我无法分配批次。对于模型而言,每批获取一个(数据,掩码)不是最佳选择。同样,上面的函数将组成一个巨大的数组并将所有内容存储在内存中。我想每批次获取n *(256、256、10),其中n是批次大小,而256和256是我的训练数据的宽度和高度,而10是通道数-类似于flow_from_directory的功能。此外,我需要一种类似于“种子”的行为来基于蒙版将它们对齐。
我的面具是常规的RGB,所以在那里没有问题。
我应该指出,输入数据的大小是固定的,不需要任何扩充。
有没有一种方法可以实现此目的而无需编辑Keras的实际ImageDataGenerator和Iterator类,因为在S.O上有一些答案。指向那个?
任何指导将不胜感激。谢谢。