PyTorch DatasetLoader张量应为割炬张量。获得<class'PIL.Image.Image'>

时间:2020-08-09 22:09:37

标签: python computer-vision pytorch python-imaging-library

data_dir="D:\ML-ComputerVision\Datasets"
train_transforms=transforms.Compose([transforms.RandomRotation(30),
                                     transforms.RandomResizedCrop(100),
                                     transforms.RandomHorizontalFlip(),
                                    transforms.Normalize([0.5,],[0.5,]),
                                    transforms.ToTensor()])

test_transforms=transforms.Compose([transforms.Normalize([0.5,],[0.5,]),
                                    transforms.ToTensor()])

train_data=datasets.ImageFolder(data_dir + "/Train",transform=train_transforms)

test_data=datasets.ImageFolder(data_dir + "/Test",transform=test_transforms)

trainloader=torch.utils.data.DataLoader(train_data,batch_size=32,shuffle=True)

testloader=torch.utils.data.DataLoader(test_data,batch_size=32,shuffle=False)

images, labels = next(iter(trainloader)) # <-- Error line

我得到的张量应该是火炬张量。即使实现了transforms.ToTensor(),也出现了错误。有什么想法可以解决吗?

2 个答案:

答案 0 :(得分:0)

transforms.Normalize之前应用transforms.ToTensor链接转换的方式。但是,尽管RandomRotationRandomResizedCropRandomHorizontalFlip是适用于PIL图像的图像变换,但transforms.Normalize仅适用于张量(文档here)。

只需将ToTensor放在Normalize之前即可。

答案 1 :(得分:0)

transforms.Normalize([0.5,],[0.5,])应用于Tensor,因此transforms.ToTensor()必须在规范化之前。