Keras: How to extend validation_split to generate a third set (i.e., a test set)?

Posted: 2018-08-21 15:38:58

Tags: python tensorflow keras

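I am using Keras with the TensorFlow backend. I am using ImageDataGenerator with the validation_split argument to split my data into a training set and a validation set. To do so, I set the subset argument of flow_from_directory to "training" and "validation", as follows: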

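from keras.preprocessing.image import ImageDataGenerator

# Must be named data_generator, since that is the name used by both calls below
data_generator = ImageDataGenerator(validation_split=0.3)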


train_gen = data_generator.flow_from_directory(my_dir, target_size=(input_size, input_size), shuffle=False, seed=13,
                                                     class_mode='categorical', batch_size=BATCH_SIZE, subset="training")

valid_gen = data_generator.flow_from_directory(my_dir, target_size=(input_size, input_size), shuffle=False, seed=13,
                                                     class_mode='categorical', batch_size=32, subset="validation")

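This is very convenient because it lets me use a single directory instead of two (one for training and one for validation). Now I'm wondering whether this process can be extended to generate a third set, i.e., a test set.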

1 answer:

Answer 0 (score: 1)

This is not possible out of the box. However, you should be able to do it with a small modification to the source code of ImageDataGenerator:

if subset is not None:
    if subset not in {'training', 'validation'}: # add a third subset here
        raise ValueError('Invalid subset name:', subset,
                         '; expected "training" or "validation".') # adjust message
    split_idx = int(len(x) * image_data_generator._validation_split) 
    # you'll need two split indices here
    if subset == 'validation':
        x = x[:split_idx]
        x_misc = [np.asarray(xx[:split_idx]) for xx in x_misc]
        if y is not None:
            y = y[:split_idx]
    elif subset == '...':  # add the extra case here
        pass  # slice x, x_misc and y between the two split indices
    else:
        x = x[split_idx:]
        x_misc = [np.asarray(xx[split_idx:]) for xx in x_misc] # change slicing
        if y is not None:
            y = y[split_idx:] # change slicing
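
For reference, the stock two-way slicing in the skeleton can be reproduced on a plain Python list. This is a minimal sketch; two_way_split is a hypothetical helper that mimics the "validation" and "training" branches (Keras itself does this inline on x, x_misc and y). The validation subset is always the first fraction of the data and the training subset is the rest.

```python
def two_way_split(samples, validation_split, subset):
    """Mimic the stock subset slicing (hypothetical helper, not part of Keras).

    validation -> first validation_split fraction, training -> the rest.
    """
    if subset not in {"training", "validation"}:
        raise ValueError(
            "Invalid subset name: %r; expected 'training' or 'validation'." % subset
        )
    # Same truncation as the Keras code: int() rounds the boundary down
    split_idx = int(len(samples) * validation_split)
    if subset == "validation":
        return samples[:split_idx]
    return samples[split_idx:]


samples = list(range(10))
print(two_way_split(samples, 0.3, "validation"))  # [0, 1, 2]
print(two_way_split(samples, 0.3, "training"))    # [3, 4, 5, 6, 7, 8, 9]
```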

Edit: here is how to modify the code:

if subset is not None:
    if subset not in {'training', 'validation', 'test'}:
        raise ValueError('Invalid subset name:', subset,
                         '; expected "training" or "validation" or "test".')
    # A list, not a generator expression, so that split_idxs[0] and split_idxs[1] work
    split_idxs = [int(len(x) * v) for v in image_data_generator._validation_split]
    if subset == 'validation':
        x = x[:split_idxs[0]]
        x_misc = [np.asarray(xx[:split_idxs[0]]) for xx in x_misc]
        if y is not None:
            y = y[:split_idxs[0]]
    elif subset == 'test':
        x = x[split_idxs[0]:split_idxs[1]]
        x_misc = [np.asarray(xx[split_idxs[0]:split_idxs[1]]) for xx in x_misc]
        if y is not None:
            y = y[split_idxs[0]:split_idxs[1]]
    else:
        x = x[split_idxs[1]:]
        x_misc = [np.asarray(xx[split_idxs[1]:]) for xx in x_misc]
        if y is not None:
            y = y[split_idxs[1]:]
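
To check the three-way slicing without Keras, the same branches can be exercised on a plain list. This is a sketch under the assumption that validation_split is a tuple of two increasing floats; three_way_split is a hypothetical name. The split indices are stored in a list because a generator expression cannot be indexed with [0].

```python
def three_way_split(samples, validation_split, subset):
    """Slice samples like the modified iterator code (hypothetical helper).

    validation_split is assumed to be a tuple (a, b) with 0 < a < b < 1:
      validation -> [0, a), test -> [a, b), training -> [b, 1)
    """
    if subset not in {"training", "validation", "test"}:
        raise ValueError(
            "Invalid subset name: %r; expected 'training', 'validation' or 'test'." % subset
        )
    # A list, not a generator expression, so that it can be indexed below
    split_idxs = [int(len(samples) * v) for v in validation_split]
    if subset == "validation":
        return samples[:split_idxs[0]]
    if subset == "test":
        return samples[split_idxs[0]:split_idxs[1]]
    return samples[split_idxs[1]:]


samples = list(range(10))
for name in ("validation", "test", "training"):
    print(name, three_way_split(samples, (0.1, 0.5), name))
# validation [0]
# test [1, 2, 3, 4]
# training [5, 6, 7, 8, 9]
```

The three slices are disjoint and together cover every sample, which is what makes the tuple boundaries a drop-in extension of the single-float split.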

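Basically, validation_split should now be a tuple of two floats instead of a single float. The validation data is the portion between 0 and validation_split[0], the test data is the portion between validation_split[0] and validation_split[1], and the training data is the portion between validation_split[1] and 1. Here is how you can use it: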

import keras
# keras_custom_preprocessing is how I named my directory
from keras_custom_preprocessing.image import ImageDataGenerator

generator = ImageDataGenerator(validation_split=(0.1, 0.5))
# First 10%: validation data; next 40%: test data; the rest: training data
gen = generator.flow_from_directory(directory='./data/', subset='test')
# Finds 40% of the images in the dir
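
The percentages in the comments follow from the two boundaries: the test fraction is validation_split[1] - validation_split[0] = 0.5 - 0.1 = 0.4. A small sketch of the resulting counts, using a hypothetical helper subset_sizes and an assumed total of 1000 samples; it applies the same int() truncation as the slicing code, so each count equals the length of the corresponding slice.

```python
def subset_sizes(num_samples, validation_split):
    """Samples per subset for a two-float validation_split (hypothetical helper)."""
    first = int(num_samples * validation_split[0])
    second = int(num_samples * validation_split[1])
    return {
        "validation": first,
        "test": second - first,
        "training": num_samples - second,
    }


print(subset_sizes(1000, (0.1, 0.5)))
# {'validation': 100, 'test': 400, 'training': 500}
```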

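You will need to modify two or three other lines in the file (the type checking has to be adjusted), but that's all. I have the modified file; if you're interested, let me know and I can host it on my GitHub.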
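
The answer does not show the adjusted type check. As one possible sketch, assuming the stock ImageDataGenerator constructor contains a range check of the form 0 < validation_split < 1: that comparison raises a TypeError in Python 3 when validation_split is a tuple, so it has to be replaced. The helper check_validation_split below is hypothetical, and it assumes the modified file only ever uses a falsy value (no split) or a two-float tuple, since the modified slicing code iterates over validation_split.

```python
def check_validation_split(validation_split):
    """Hypothetical range check for a two-float validation_split.

    Accepts a falsy value (no split) or a tuple/list (a, b) with 0 < a < b < 1.
    A single float is rejected, because the modified slicing code iterates
    over validation_split.
    """
    if not validation_split:
        return validation_split
    if not isinstance(validation_split, (tuple, list)) or len(validation_split) != 2:
        raise ValueError(
            "`validation_split` must be a tuple of two floats. Received: %r"
            % (validation_split,)
        )
    first, second = validation_split
    if not 0 < first < second < 1:
        raise ValueError(
            "`validation_split` must satisfy 0 < a < b < 1. Received: %r"
            % (validation_split,)
        )
    # Normalize lists to tuples so the stored value has a single type
    return tuple(validation_split)
```

In the modified constructor, the result of such a check would be what gets stored in self._validation_split.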