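I am working with a non-torchvision dataset that I load with ImageFolder. I am trying to split it into 20% validation and 80% training. The only method I can find in the PyTorch library for splitting a dataset is random_split, but its split is different on every run. Is there a way in PyTorch to split the dataset in a specific, fixed way instead?
Here is my code that loads the dataset and splits it randomly.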
import torch
from torchvision import datasets, transforms

transformations = transforms.Compose([
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
TrafficSignSet = datasets.ImageFolder(root='./train/', transform=transformations)
####### split data
train_size = int(0.8 * len(TrafficSignSet))
test_size = len(TrafficSignSet) - train_size
train_dataset_split, test_dataset_split = torch.utils.data.random_split(TrafficSignSet, [train_size, test_size])
#######put into a Dataloader
train_dataset = torch.utils.data.DataLoader(train_dataset_split, batch_size=32, shuffle=True)
test_dataset = torch.utils.data.DataLoader(test_dataset_split, batch_size=32, shuffle=True)
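Answer 0 (score: 1)
If you look "under the hood" of random_split, you will see that it uses torch.utils.data.Subset to do the actual splitting. You can do the same thing yourself with fixed indices: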
import random
indices = list(range(len(TrafficSignSet)))
random.seed(310)  # fix the seed so the shuffle is the same every time
random.shuffle(indices)
train_dataset_split = torch.utils.data.Subset(TrafficSignSet, indices[:train_size])
val_dataset_split = torch.utils.data.Subset(TrafficSignSet, indices[train_size:])
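If you want to reuse this in several places, the index logic above can be wrapped in a small helper. This is only a sketch: the function name fixed_split_indices and its parameters are made up for illustration, and it uses a private random.Random instance instead of random.seed so that the global random state is left alone.

```python
import random


def fixed_split_indices(n_samples, train_fraction=0.8, seed=310):
    """Return (train_indices, val_indices) that are identical on every run.

    Hypothetical helper, not part of PyTorch.
    """
    indices = list(range(n_samples))
    # A dedicated Random instance gives a reproducible shuffle
    # without reseeding the module-level generator.
    random.Random(seed).shuffle(indices)
    train_size = int(train_fraction * n_samples)
    return indices[:train_size], indices[train_size:]


# Example with 100 samples; in practice pass len(TrafficSignSet).
train_idx, val_idx = fixed_split_indices(100)
```

The two lists can then be passed to torch.utils.data.Subset(TrafficSignSet, train_idx) and torch.utils.data.Subset(TrafficSignSet, val_idx), exactly as in the snippet above.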