利用K折交叉验证平衡不平衡的数据集

时间:2019-08-21 16:26:00

标签: python deep-learning pytorch data-augmentation

我正在尝试使用Pytorch对不平衡的图像数据集(1:250类图像,0类:4000ish图像)进行训练/验证CNN,现在,我仅在训练集上尝试了增强(谢谢) @jodag)。但是,我的模型仍在学习以明显更多的图像来吸引全班同学。

我想找到补偿我不平衡数据集的方法。

我曾考虑过使用不平衡数据采样器(https://github.com/ufoym/imbalanced-dataset-sampler)使用过采样/欠采样,但是我已经使用采样器来选择用于5折验证的索引。有没有一种方法可以使用下面的代码实现交叉验证并添加此采样器?同样,是否有一种方法可以比另一个标签更频繁地扩展一个标签?根据这些问题,是否有其他更简便的方法可以解决我尚未研究的不平衡数据集?

这是到目前为止我所拥有的例子

total_set = datasets.ImageFolder(PATH)
KF_splits = KFold(n_splits= 5, shuffle = True, random_state = 42)

for train_idx, valid_idx in KF_splits.split(total_set):
    #sampler to get indices for cross validation
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    #Use a wrapper to apply augmentation only to training set
    #These are dataloaders that pull images from the same folder but sort into validation and training sets
    #Though transforms augment only the training set, it doesn't address
    #the underlying issue of a heavily unbalanced dataset

    train_loader = torch.utils.data.DataLoader(
        WrapperDataset(total_set, transform=data_transforms['train']),
        batch_size=32, sampler=ImbalancedDatasetSampler(total_set))
    valid_loader = torch.utils.data.DataLoader(
        WrapperDataset(total_set, transform=data_transforms['val']),
        batch_size=32)

    print("Fold:" + str(i))

    for epoch in range(epochs):
        #Train/validate model below

`

感谢您的时间和帮助!

0 个答案:

没有答案