转换不适用于数据集

时间:2018-08-31 18:19:43

标签: python pytorch

我是pytorch的新手,想了解一些东西。

我正在按以下方式加载MNIST:

transform_train = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Resize(size, interpolation=2),
     # transforms.Grayscale(num_output_channels=1),
     transforms.RandomHorizontalFlip(p=0.5),
     transforms.Normalize((mean), (std))])


trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

但是,当我浏览数据集trainloader.dataset.train_data[0]时,我得到的张量在[0,255]范围内,形状为(28,28)。

我想念什么?这是因为转换没有直接应用于数据加载器,而是仅在运行时?否则我该如何浏览我的数据?

1 个答案:

答案 0 :(得分:4)

在调用__getitem__的{​​{1}}方法时应用转换。例如,查看Dataset数据集类的__getitem__方法:https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py#L62

MNIST

为训练集索引def __getitem__(self, index): """ Args: index (int): Index Returns: tuple: (image, target) where target is index of the target class. """ img, target = self.data[index], self.targets[index] # doing this so that it is consistent with all other datasets # to return a PIL Image img = Image.fromarray(img.numpy(), mode='L') if self.transform is not None: img = self.transform(img) if self.target_transform is not None: target = self.target_transform(target) return img, target 实例时,将调用__getitem__方法,例如:

MNIST

有关trainset[0] 的更多信息:https://docs.python.org/3.6/reference/datamodel.html#object.getitem

__getitem__Resize应该在RandomHorizontalFlip之前的原因是它们作用于PIL Images,并且Pytorch中的所有数据集为了保持一致性将数据加载为{{1 }}。实际上,您可以在这里看到他们通过以下方式强制执行该行为:

ToTensor

一旦您拥有相应索引的PIL Image,就可以应用转换

img = Image.fromarray(img.numpy(), mode='L')

PIL Imageif self.transform is not None: img = self.transform(img) 转换为ToTensorPIL Image减去平均值,然后除以您提供的标准差。

最终,某些转换会通过

应用于标签
torch.Tensor

最后,返回处理后的图像和处理后的标签。所有这一切都在单个Normalize调用中发生。

if self.target_transform is not None:
    target = self.target_transform(target)

显示

trainset[key]