Here is a code snippet from the PyTorch transfer learning tutorial that loads images as datasets:
import os

import torch
from torchvision import datasets, transforms

# Random augmentation for training; deterministic resize and crop for validation.
# Both pipelines normalize with the same per-channel mean and std.
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# Expects data/train/<class>/... and data/val/<class>/... on disk.
data_dir = 'data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                              shuffle=True, num_workers=4)
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
Here is one of the examples in the dataset:
image_datasets['val'][0]:
(tensor([[[ 2.2489, 2.2489, 2.2489, ..., 2.2489, 2.2489, 2.2489],
[ 2.2489, 2.2489, 2.2489, ..., 2.2489, 2.2489, 2.2489],
[ 2.2489, 2.2489, 2.2489, ..., 2.2489, 2.2489, 2.2489],
...,
[ 2.2489, 2.2489, 2.2489, ..., 2.2489, 2.2489, 2.2489],
[ 2.2489, 2.2489, 2.2489, ..., 2.2489, 2.2489, 2.2489],
[ 2.2489, 2.2489, 2.2489, ..., 2.2489, 2.2489, 2.2489]],
[[ 2.4286, 2.4286, 2.4286, ..., 2.4286, 2.4286, 2.4286],
[ 2.4286, 2.4286, 2.4286, ..., 2.4286, 2.4286, 2.4286],
[ 2.4286, 2.4286, 2.4286, ..., 2.4286, 2.4286, 2.4286],
...,
[ 2.4286, 2.4286, 2.4286, ..., 2.4286, 2.4286, 2.4286],
[ 2.4286, 2.4286, 2.4286, ..., 2.4286, 2.4286, 2.4286],
[ 2.4286, 2.4286, 2.4286, ..., 2.4286, 2.4286, 2.4286]],
[[ 2.6400, 2.6400, 2.6400, ..., 2.6400, 2.6400, 2.6400],
[ 2.6400, 2.6400, 2.6400, ..., 2.6400, 2.6400, 2.6400],
[ 2.6400, 2.6400, 2.6400, ..., 2.6400, 2.6400, 2.6400],
...,
[ 2.6400, 2.6400, 2.6400, ..., 2.6400, 2.6400, 2.6400],
[ 2.6400, 2.6400, 2.6400, ..., 2.6400, 2.6400, 2.6400],
[ 2.6400, 2.6400, 2.6400, ..., 2.6400, 2.6400, 2.6400]]]), 0)
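Is there a way (ideally a best practice) to change an example in the dataset, for instance changing its label from 0 to 1? The following does not work: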
image_datasets['val'][0] = (image_datasets['val'][0][0], 1)
Answer 0 (score: 1)
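Yes, though not (easily) programmatically. The labels come from torchvision.datasets.ImageFolder and reflect the directory structure of your dataset (as seen on your hard drive). First, I suspect you may want to know the directory name as a string. This is poorly documented, but the dataset (not the data loader) has a classes attribute that stores these names. So: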
img, lbl = image_datasets['val'][0]
directory_name = image_datasets['val'].classes[lbl]
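For context, the integer labels are simply positions in the sorted list of class directory names, which ImageFolder also exposes as class_to_idx. The sketch below imitates that mapping in plain Python so it runs without torchvision; it is an approximation of the behaviour, not the library's code, and the directory names are hypothetical.

```python
import os
import tempfile

# Hypothetical class directories; real names depend on your dataset.
with tempfile.TemporaryDirectory() as root:
    for name in ("zebra", "ant", "bee"):
        os.makedirs(os.path.join(root, name))

    # Sorted directory names become the class list ...
    classes = sorted(e.name for e in os.scandir(root) if e.is_dir())
    # ... and each name's position in that list is its integer label.
    class_to_idx = {name: idx for idx, name in enumerate(classes)}

print(classes)
print(class_to_idx)
```

This is why renaming or adding a class directory can shift the integer labels of the other classes.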
If you want the dataset to consistently return these directory names instead of class IDs, you can use the target_transform API as follows:
image_datasets['val'].target_transform = lambda idx: image_datasets['val'].classes[idx]
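With this target_transform set, the dataset returns strings instead of IDs from then on. If you are looking for something more advanced, you can reimplement or inherit from ImageFolder or DatasetFolder and implement your own semantics. The only methods you need to provide are __len__ and __getitem__.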
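To come back to the original question (turning label 0 into label 1), here is a minimal sketch of two options, under some assumptions. make_label_remapper and RelabeledDataset are hypothetical names, not part of torchvision. The first builds a function you could assign to target_transform to remap a whole class; the second wraps any map-style dataset and overrides labels for individual sample indices. The wrapper is a plain class so the sketch runs without torch; in a real project you would probably subclass torch.utils.data.Dataset or ImageFolder instead.

```python
def make_label_remapper(mapping):
    """Build a function that could be used as a target_transform."""
    def remap(label):
        # Labels missing from the mapping pass through unchanged.
        return mapping.get(label, label)
    return remap


class RelabeledDataset:
    """Wrap a map-style dataset and override labels of chosen indices."""

    def __init__(self, base, overrides):
        self.base = base                  # needs __len__ and __getitem__
        self.overrides = dict(overrides)  # {sample index: new label}

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        sample, label = self.base[index]
        return sample, self.overrides.get(index, label)


# Per-class remap: every label 0 is reported as 1.
# With the question's data: image_datasets['val'].target_transform = remap
remap = make_label_remapper({0: 1})
print(remap(0), remap(2))

# Per-sample override, with a toy list standing in for image_datasets['val'].
toy = [("img0", 0), ("img1", 0), ("img2", 1)]
relabeled = RelabeledDataset(toy, {0: 1})
print(relabeled[0], relabeled[1], len(relabeled))
```

Note that remapping a whole class merges it with the target class if that class already exists, while the wrapper only touches the indices you list. Neither approach changes anything on disk.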