Balancing an imbalanced image dataset with several classes

Time: 2021-04-01 20:31:36

Tags: python tensorflow keras deep-learning data-augmentation

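I have a dataset with 12 classes in the base directory. However, the 12 classes contain different numbers of images, and this imbalance hurts the overall accuracy. Should I therefore apply data augmentation only to the specific classes that have less data?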

Number of images per class:

#Dummy Classes

    [AAAA: 713
    ABCD: 274
    ACBD: 335
    ADBC: 576
    BBBB: 538
    BACD: 607
    BCAD: 253
    BDAD: 257
    CCCC: 463
    CABD: 309
    CBAD: 452
    CDAB: 762]

So I want to apply data augmentation to increase the amount of data in the smaller classes, but when I apply it, the number of images does not increase. In addition, I want to generate the augmented images alongside the original ones, which means the input and output directories are the same.

Augmentation code for a specific (individual) class:

from keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(
    rotation_range=45,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range = 0.2,
    zoom_range = 0.2,
    horizontal_flip=True,
    fill_mode = 'reflect', cval = 125)

i = 0
for batch in datagen.flow_from_directory(directory = ('/content/dataset/ABCD'),
                                         batch_size = 317,
                                         target_size = (256, 256),
                                         color_mode = ('rgb'),
                                         save_to_dir = ('/content/dataset/ABCD'),
                                         save_prefix = ('aug'),
                                         save_format = ('png')):
    i += 1
    if i > 100:
        break
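Two things are worth checking in the loop above. First, `flow_from_directory` treats each subdirectory of `directory` as a class, so pointing it at a class folder that contains only image files finds no images, which may explain why nothing new is written; passing the parent folder together with `classes=['ABCD']` is one way to restrict it to a single class. Second, the fixed `if i > 100` stop says nothing about how many files end up on disk. The following is a minimal planning sketch, assuming a Keras-style iterator that yields full batches of `batch_size` followed by one partial batch per pass over the data, and that every image in a batch is saved; the helper name `batches_needed` and the choice of the largest class (762) as the target are illustrative assumptions.

```python
# Image counts per class, copied from the question
counts = {
    "AAAA": 713, "ABCD": 274, "ACBD": 335, "ADBC": 576,
    "BBBB": 538, "BACD": 607, "BCAD": 253, "BDAD": 257,
    "CCCC": 463, "CABD": 309, "CBAD": 452, "CDAB": 762,
}


def batches_needed(n_images, n_extra, batch_size):
    """Return (batches, images_saved) needed to save at least n_extra images.

    Assumes each pass over the n_images originals yields full batches of
    batch_size followed by one partial batch of n_images % batch_size (if any),
    and that every image in a batch is written to save_to_dir.
    """
    if n_extra <= 0:
        return 0, 0
    if n_images <= 0:
        raise ValueError("class has no images to augment")
    full, rest = divmod(n_images, batch_size)
    epoch = [batch_size] * full + ([rest] if rest else [])  # batch sizes in one pass
    batches = saved = 0
    while saved < n_extra:
        saved += epoch[batches % len(epoch)]
        batches += 1
    return batches, saved


target = max(counts.values())  # assumption: bring every class up to the largest one
plan = {
    label: batches_needed(n, target - n, batch_size=32)
    for label, n in counts.items()
}
for label, (batches, saved) in plan.items():
    print(f"{label}: {counts[label]} originals, pull {batches} batches -> {saved} new images")
```

In the question's loop, the planned `batches` value would replace the hard-coded `100`. For comparison, with `batch_size=317` one batch already holds all 274 `ABCD` images, so if the iterator did find them, 101 batches would write 101 × 274 = 27,674 files.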

1 answer:

Answer 0 (score: 0)

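As I mentioned, I am using flow_from_dataframe, so you can first create a CSV file for your dataset if you don't have one. My idea is to repeat the current dataset up to a fixed number of samples per label; for example, you might want 762 samples for every label in your dataset. Here is my approach, using a dummy dataset.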
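The approach relies on a table with one row per image. A minimal sketch for building that CSV from a `base_dir/<class_name>/<image>` layout like the question's, using only the standard library; the function name `directory_to_csv` and the extension list are assumptions, and the column names are chosen to match `x_col="img_path"` and `y_col="label"` in the code that follows.

```python
import csv
from pathlib import Path

# Assumption: adjust to the file types actually present in the dataset
IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".bmp"}


def directory_to_csv(base_dir, csv_path):
    """Write one (img_path, label) row per image found under base_dir/<label>/.

    Files directly inside base_dir and non-image files are skipped.
    Returns the number of rows written.
    """
    rows = []
    for class_dir in sorted(Path(base_dir).iterdir()):
        if not class_dir.is_dir():
            continue
        for img in sorted(class_dir.iterdir()):
            if img.is_file() and img.suffix.lower() in IMAGE_EXTS:
                rows.append((str(img), class_dir.name))
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["img_path", "label"])  # same names as x_col / y_col
        writer.writerows(rows)
    return len(rows)
```

For example, `directory_to_csv('/content/dataset', 'dataset.csv')` followed by `pd.read_csv('dataset.csv')` gives a DataFrame in the shape that `balance_data` expects.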

import numpy as np
import pandas as pd
from keras.preprocessing.image import ImageDataGenerator
import cv2

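# Create a dummy 8-bit RGB image to be able to use flow_from_dataframe later
# (PNG needs integer pixels, so scale the random floats to 0..255 and cast to uint8)
cv2.imwrite('temp.png', (np.random.rand(3, 3, 3) * 255).astype(np.uint8))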

labels = [] # Create some unbalanced dataset
for i in range(10):
    labels.append('a')

for i in range(5):
    labels.append('b')

for i in range(3):
    labels.append('c') 

# Create a dataframe
df = pd.DataFrame({'img_path':['./temp.png']*len(labels),'label':labels})

# print(df.head())

def balance_data(df, target_size=12):
    """
    Increase the number of samples to target_size for every label

        Example:
        Current size of the label a: 10
        Target size: 23

        repeat, mod = divmod(target_size, current_size)
        2, 3 = divmod(23, 10)

        Target size: current_size * repeat + mod

    Repeat this example for every label in the dataset.
    Labels that already have at least target_size samples are kept unchanged.
    """

    # Group by the plain column name (not a one-element list) so group keys are scalars
    df_groups = df.groupby('label')
    balanced_parts = []

    for label in df_groups.groups.keys():
        df_group = df_groups.get_group(label)
        df_label = df_group.sample(frac=1)  # shuffle the rows of this label
        current_size = len(df_label)

        if current_size >= target_size:
            # If current size is big enough, keep this label as it is
            df_label_new = df_label
        else:
            # Repeat the current dataset if it is smaller than target_size
            repeat, mod = divmod(target_size, current_size)

            df_label_new = pd.concat([df_label] * repeat, ignore_index=True, axis=0)
            # Fill the remaining mod rows with a random subset of this label
            df_label_remainder = df_group.sample(n=mod)

            df_label_new = pd.concat([df_label_new, df_label_remainder], ignore_index=True, axis=0)

            # print(df_label_new)

        balanced_parts.append(df_label_new)

    # Concatenate once at the end instead of growing an empty DataFrame in the loop
    return pd.concat(balanced_parts, ignore_index=True, axis=0)

df_balanced = balance_data(df)
# print(df_balanced)

# Every row (including the duplicated ones) gets a fresh random transformation each time the generator draws it
image_datagen = ImageDataGenerator(
    rotation_range=45,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range = 0.2,
    zoom_range = 0.2, 
    horizontal_flip=True,
    fill_mode = 'reflect', cval = 125)  # cval is only used when fill_mode='constant'

image_generator = image_datagen.flow_from_dataframe(
            dataframe=df_balanced,
            x_col="img_path",
            y_col="label",
            class_mode="categorical",
            batch_size=4,
            shuffle=True
            )

# x,y=next(image_generator)
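Applied to the class counts from the question with `target_size=762`, the docstring's `divmod` arithmetic works out per label as shown by this minimal standard-library sketch; the helper name `repeat_plan` is hypothetical, and it only reproduces the repeat/remainder split that `balance_data` computes, without touching any images.

```python
# Image counts per class, copied from the question
counts = {
    "AAAA": 713, "ABCD": 274, "ACBD": 335, "ADBC": 576,
    "BBBB": 538, "BACD": 607, "BCAD": 253, "BDAD": 257,
    "CCCC": 463, "CABD": 309, "CBAD": 452, "CDAB": 762,
}


def repeat_plan(counts, target_size):
    """Map each label to (repeat, mod) as computed in balance_data.

    Labels that already have target_size or more samples get (1, 0),
    i.e. they are kept unchanged.
    """
    plan = {}
    for label, current_size in counts.items():
        if current_size >= target_size:
            plan[label] = (1, 0)
        else:
            plan[label] = divmod(target_size, current_size)
    return plan


plan = repeat_plan(counts, target_size=762)
for label, (repeat, mod) in plan.items():
    n = counts[label]
    # Each label ends up with current_size * repeat + mod rows
    print(f"{label}: {n} * {repeat} + {mod} = {n * repeat + mod}")
```

A small class such as `BCAD` ends up with three full copies of each image plus 3 extra rows; the random transformations in `image_datagen` are what keep those copies from being identical training samples.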

I hope the code is self-explanatory. Let me know if you need further help.