如何使用pytorch将数据集拆分为自定义训练集和自定义验证集?

时间:2020-05-05 21:52:11

标签: python machine-learning neural-network pytorch

我正在使用非Torchvision数据集,并已使用ImageFolder方法将其提取。我正在尝试将数据集分成20%的验证集和80%的训练集。我只能从PyTorch库中找到此方法(random_split),该方法允许拆分数据集。但是,每次都是随机的。我想知道是否有一种方法可以在PyTorch库中以特定数量分割数据集?

这是我提取数据集并随机分割的代码。

transformations = transforms.Compose([
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

TrafficSignSet = datasets.ImageFolder(root='./train/', transform=transformations)

####### split data
train_size = int(0.8 * len(TrafficSignSet))
test_size = len(TrafficSignSet) - train_size
train_dataset_split, test_dataset_split = torch.utils.data.random_split(TrafficSignSet, [train_size, test_size])

#######put into a Dataloader
train_dataset = torch.utils.data.DataLoader(train_dataset_split, batch_size=32, shuffle=True)
test_dataset = torch.utils.data.DataLoader(test_dataset_split, batch_size=32, shuffle=True)

1 个答案:

答案 0 :(得分:1)

如果您查看random_split的“内幕”,您会发现它使用torch.utils.data.Subset进行实际拆分。您可以使用固定索引自己这样做:

import random

indices = list(range(len(TrafficSignSet))
random.seed(310)  # fix the seed so the shuffle will be the same everytime
random.shuffle(indices)
train_dataset_split = torch.utils.data.Subset(TrafficSignSet, indices[:train_size])
val_dataset_split = torch.utils.data.Subset(TrafficSignSet, indices[train_size:])