我是pytorch的新手,想了解一些东西。
我正在按以下方式加载MNIST:
transform_train = transforms.Compose(
[transforms.ToTensor(),
transforms.Resize(size, interpolation=2),
# transforms.Grayscale(num_output_channels=1),
transforms.RandomHorizontalFlip(p=0.5),
transforms.Normalize((mean), (std))])
trainset = torchvision.datasets.MNIST(root='./data', train=True,
download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
shuffle=True, num_workers=2)
但是,当我浏览数据集trainloader.dataset.train_data[0]
时,我得到的张量在[0,255]范围内,形状为(28,28)。
我想念什么?这是因为转换没有直接应用于数据加载器,而是仅在运行时?否则我该如何浏览我的数据?
答案 0 :(得分:4)
在调用__getitem__
的{{1}}方法时应用转换。例如,查看Dataset
数据集类的__getitem__
方法:https://github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py#L62
MNIST
为训练集索引def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
img, target = self.data[index], self.targets[index]
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img.numpy(), mode='L')
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
实例时,将调用__getitem__
方法,例如:
MNIST
有关trainset[0]
的更多信息:https://docs.python.org/3.6/reference/datamodel.html#object.getitem
__getitem__
和Resize
应该在RandomHorizontalFlip
之前的原因是它们作用于PIL Images,并且Pytorch中的所有数据集为了保持一致性将数据加载为{{1 }}。实际上,您可以在这里看到他们通过以下方式强制执行该行为:
ToTensor
一旦您拥有相应索引的PIL Image
,就可以应用转换
img = Image.fromarray(img.numpy(), mode='L')
PIL Image
将if self.transform is not None:
img = self.transform(img)
转换为ToTensor
,PIL Image
减去平均值,然后除以您提供的标准差。
最终,某些转换会通过
应用于标签torch.Tensor
最后,返回处理后的图像和处理后的标签。所有这一切都在单个Normalize
调用中发生。
if self.target_transform is not None:
target = self.target_transform(target)
显示
trainset[key]