假设我正在使用以下电话:
trainset = torchvision.datasets.ImageFolder(root="imgs/", transform=transform)
trainloader = torch.utils.data.DataLoader(trainset,batch_size=4,suffle=True,num_workers=1)
据我所知,这将火车组定义为文件夹中的所有图像"图像",标签由特定文件夹位置定义。
我的问题是 - 是否有任何直接/简单的方法将列车集定义为此文件夹中图像的子样本?例如,将trainset定义为来自每个子文件夹的10个图像的随机样本?
提前致谢
答案 0 :(得分:1)
您可以将类DatasetFolder
(或ImageFolder)包装在另一个类中以限制数据集:
class LimitDataset(data.Dataset):
def __init__(self, dataset, n):
self.dataset = dataset
self.n = n
def __len__(self):
return self.n
def __getitem__(self, i):
return self.dataset[i]
您还可以在LimitDataset
中的索引和原始数据集中的索引之间定义一些映射,以定义更复杂的行为(例如随机子集)。
如果要限制每个纪元的批次而不是数据集大小:
from itertools import islice
for data in islice(dataloader, 0, batches_per_epoch):
...
请注意,如果您使用此shuffle,数据集大小将相同,但每个纪元将看到的数据将受到限制。如果您不对数据集进行随机播放,这也会限制数据集的大小。