I'm trying to build a binary CNN classifier for an imbalanced dataset (class 0 = 4000 images, class 1 = roughly 250 images), and I want to run 5-fold cross-validation on it. Currently I load the training set into an ImageFolder, which applies the transforms/augmentations(?), and pass that to a DataLoader. However, this means both my training split and my validation split end up containing augmented data.
I originally applied the transforms offline (offline augmentation?) to balance my dataset, but from this thread (https://stats.stackexchange.com/questions/175504/how-to-do-data-augmentation-and-train-validate-split) it seems that augmenting only the training set is the ideal approach. I'd also prefer to train my model on augmented-only training data and then validate it on non-augmented data via 5-fold cross-validation.
My data is organized as root/label/image, with 2 label folders (0 and 1) and the images sorted into their respective labels.
total_set = datasets.ImageFolder(ROOT, transform=data_transforms['my_transforms'])

# Eventually I plan to run cross-validation as such:
splits = KFold(n_splits=5, shuffle=True, random_state=42)
for train_idx, valid_idx in splits.split(total_set):
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)
    train_loader = torch.utils.data.DataLoader(total_set, batch_size=32, sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(total_set, batch_size=32, sampler=valid_sampler)
    model.train()
    # Model train/eval works but may be overpredicting
I'm sure something in this code is suboptimal or outright wrong, but I can't find any documentation on augmenting only the training folds within cross-validation!
Any help would be appreciated!
Answer 0 (score: 0)
One way to do this is to implement a wrapper Dataset class that applies transforms to the output of an ImageFolder dataset. For example:
class WrapperDataset:
    def __init__(self, dataset, transform=None, target_transform=None):
        self.dataset = dataset
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        image, label = self.dataset[index]
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label

    def __len__(self):
        return len(self.dataset)
Then you can wrap the one underlying dataset with different transforms in your code:
total_set = datasets.ImageFolder(ROOT)

# Eventually I plan to run cross-validation as such:
splits = KFold(n_splits=5, shuffle=True, random_state=42)
for train_idx, valid_idx in splits.split(total_set):
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)
    train_loader = torch.utils.data.DataLoader(
        WrapperDataset(total_set, transform=data_transforms['train_transforms']),
        batch_size=32, sampler=train_sampler)
    valid_loader = torch.utils.data.DataLoader(
        WrapperDataset(total_set, transform=data_transforms['valid_transforms']),
        batch_size=32, sampler=valid_sampler)
    # train/validate now
I haven't tested this code since I don't have your full code/model, but the concept should be clear: the train and validation loaders index the same underlying images, and only the wrapper decides which transforms get applied.
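To show the mechanism in isolation, here is a minimal, framework-free sketch of the same wrapper pattern. The toy dataset (a list of (sample, label) pairs) and the "multiply by 10" transform are made up for illustration; they stand in for ImageFolder and your augmentation pipeline. The key point is that two views of the same data can disagree on transforms:

```python
class WrapperDataset:
    """Applies optional transforms on top of an existing indexable dataset."""
    def __init__(self, dataset, transform=None, target_transform=None):
        self.dataset = dataset
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        image, label = self.dataset[index]
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label

    def __len__(self):
        return len(self.dataset)


# Toy stand-in for ImageFolder: (sample, label) pairs.
toy_set = [(1, 0), (2, 0), (3, 1)]

# Only the "training" view gets the (made-up) augmentation.
train_view = WrapperDataset(toy_set, transform=lambda x: x * 10)
valid_view = WrapperDataset(toy_set)  # no transform: raw samples

print(train_view[2])  # -> (30, 1): transformed sample, untouched label
print(valid_view[2])  # -> (3, 1): same underlying item, no augmentation
```

One further note: since your classes are heavily imbalanced (~16:1), sklearn's StratifiedKFold (which takes the labels as a second argument to split) is probably a better fit than plain KFold, so every fold preserves the class ratio.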