如何统一分配火车组

时间:2018-04-22 08:36:17

标签: python machine-learning keras

我有以下目录结构

data/
    train/
        Cat 1/ ### 5000 pictures
            dog001.jpg

            ...
        cat 2/ ### 3000 pictures
            cat001.jpg

       Cat 3/ ### 50000 pictures
            Unicorn.jpg

            ...
        Cat 4/ ### 10000 pictures
            Angels.jpg

我使用以下代码加载我的图片

datagen = ImageDataGenerator(rescale=1./255)

# automagically retrieve images and their classes for train and validation sets
train_generator = datagen.flow_from_directory(
        train_data_dir,
        target_size=(img_width, img_height),
        batch_size=batch_size,
        class_mode="categorical")

由于我的数据分布不均匀,所以我的模型不合适,它变得偏向Cat 3,所以我如何加载一个统一的列车数据?

2 个答案:

答案 0 :(得分:0)

你有两种方法:

  1. cat3删除一些数据,以便数据可以统一改组
  2. 将数据添加到其他类
  3. 1非常简单,要添加数据,您可以从其他不太频繁的类中复制数据,或者更好的方法是从现有的

    生成新数据

    通过操纵图像,你可以将一行/列设置为空白,你可以旋转图像或移动它,我用这样的smth来实现那些效果一个28x28的图像

    import numpy as np
    from scipy.ndimage.interpolation import rotate, shift
    
    def rand_jitter(temp, prob=0.5):
        np.random.seed(1337)  # for reproducibility
        if np.random.random() > prob:
            temp[np.random.randint(0,28,1), :] = 0
        if np.random.random() > prob:
            temp[:, np.random.randint(0,28,1)] = 0
        if np.random.random() > prob:
            temp = shift(temp, shift=(np.random.randint(-3,4,2)))
        if np.random.random() > prob:
            temp = rotate(temp, angle = np.random.randint(-20,21,1), reshape=False)
        return temp
    

    通过这种方式,您可以使用更多数据训练您的网络,并将其推广并使其预测最稳健

答案 1 :(得分:0)

您不必删除任何数据点,并且应该保留尽可能多的数据点。

为此,您需要向现有的keras图像数据生成器添加一些代码,但它应该很简单。这里的一般想法是提供一个自定义采样函数,根据目标类统一采样训练数据点,您可以分三步完成:

  1. 建立字典LUT = {' class-1' :[class-1 files],' class-2' :[class-2 files],...,' class-k':[class-k files]}

  2. 以均匀随机的方式在LUT中选择一个键

  3. LUT[key]以统一随机的方式选择一个文件