Here is a code example:
```python
import os

import torch
from torchvision import datasets, transforms

dataset = datasets.MNIST(root=root, train=istrain, transform=None)  # preserve raw img
print(type(dataset[0][0]))
# <class 'PIL.Image.Image'>
dataset = torch.utils.data.Subset(dataset, indices=SAMPLED_INDEX)  # for resample
for ind in range(len(dataset)):
    img, label = dataset[ind]  # <class 'PIL.Image.Image'> <class 'int'>/<class 'numpy.int64'>
    img.save(fp=os.path.join(saverawdir, f'{ind:02d}-{int(label):02d}.png'))
dataset.transform = transforms.Compose([
    transforms.RandomResizedCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])  # transform for net forwarding
print(type(dataset[0][0]))
# expected <class 'torch.Tensor'>, but it's still <class 'PIL.Image.Image'>
```
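A likely cause, sketched with a toy stand-in: `torch.utils.data.Subset.__getitem__` forwards indexing to the wrapped dataset (`self.dataset[self.indices[idx]]`), so assigning `transform` to the `Subset` itself only creates a new attribute that nothing reads. `ToyDataset` and `to_tensor` below are hypothetical stand-ins for `datasets.MNIST` and the `Compose` pipeline, assuming the real dataset likewise applies `self.transform` inside its `__getitem__`.

```python
import torch
from torch.utils.data import Dataset, Subset


class ToyDataset(Dataset):
    """Hypothetical stand-in for datasets.MNIST: applies self.transform in __getitem__."""

    def __init__(self, n, transform=None):
        self.data = list(range(n))
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img, label = self.data[idx], self.data[idx] % 10
        if self.transform is not None:
            img = self.transform(img)
        return img, label


def to_tensor(x):
    # Hypothetical transform standing in for transforms.Compose([...])
    return torch.tensor([float(x)])


base = ToyDataset(20)
subset = Subset(base, indices=[3, 7, 11])

# Setting .transform on the Subset only adds an unused attribute.
subset.transform = to_tensor
print(type(subset[0][0]))  # <class 'int'>

# Subset forwards indexing to subset.dataset, so set the transform there instead.
subset.dataset.transform = to_tensor
print(type(subset[0][0]))  # <class 'torch.Tensor'>
```

For the snippet above, the analogous (untested here) change would be `dataset.dataset.transform = transforms.Compose([...])`. Note that this mutates the underlying MNIST object, so every `Subset` built on it sees the same transform.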
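Since the dataset is randomly sampled, I don't want to reload a new dataset with the transform; I just want to apply the transform to the dataset that already exists.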
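If mutating the underlying dataset is undesirable, one option is to wrap the existing `Subset` in a small dataset that applies the transform on read, so nothing is reloaded or resampled. This is a minimal sketch: `TransformedDataset` is a hypothetical helper (not part of torch or torchvision), and the list `raw` plus `to_tensor` stand in for the sampled MNIST subset and the `Compose` pipeline.

```python
import torch
from torch.utils.data import Dataset, Subset


class TransformedDataset(Dataset):
    """Hypothetical wrapper: applies `transform` to the image of each (img, label) pair."""

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, label = self.dataset[idx]
        if self.transform is not None:
            img = self.transform(img)
        return img, label


# Hypothetical raw (img, label) pairs standing in for the untransformed MNIST data.
raw = [(i, i % 10) for i in range(20)]
raw_subset = Subset(raw, indices=[3, 7, 11])


def to_tensor(x):
    # Hypothetical transform standing in for transforms.Compose([...])
    return torch.tensor([float(x)])


train_view = TransformedDataset(raw_subset, transform=to_tensor)
print(type(raw_subset[0][0]))  # <class 'int'> (raw samples untouched)
print(type(train_view[0][0]))  # <class 'torch.Tensor'>
```

With the post's objects this would look like `train_set = TransformedDataset(dataset, transform=transforms.Compose([...]))`. The raw `Subset` keeps returning PIL images (handy for the saving loop), and a random transform such as `RandomResizedCrop` is re-drawn on every access.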
Thanks for your help :D