我正在尝试通过加载程序进行迭代以检查其是否正常运行,但是给出了以下错误:
TypeError: img should be PIL Image. Got <class 'torch.Tensor'>
我尝试同时添加transforms.ToTensor()
和transforms.ToPILImage()
,这给我一个错误,要求相反。即使用ToPILImage()
,它将要求张量,反之亦然。
# Imports here
%matplotlib inline
import matplotlib.pyplot as plt
from torch import nn, optim
import torch.nn.functional as F
import torch
from torchvision import transforms, datasets, models
import seaborn as sns
import pandas as pd
import numpy as np
data_dir = 'flowers'
train_dir = data_dir + '/train'
valid_dir = data_dir + '/valid'
test_dir = data_dir + '/test'
#Creating transform for training set
train_transforms = transforms.Compose(
[transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.RandomHorizontalFlip(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
#Creating transform for test set
test_transforms = transforms.Compose(
[transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])
#transforming for all data
train_data = datasets.ImageFolder(train_dir, transform=train_transforms)
test_data = datasets.ImageFolder(test_dir, transform = test_transforms)
valid_data = datasets.ImageFolder(valid_dir, transform = test_transforms)
#Creating data loaders for test and training sets
trainloader = torch.utils.data.DataLoader(train_data, batch_size = 32,
shuffle = True)
testloader = torch.utils.data.DataLoader(test_data, batch_size=32)
images, labels = next(iter(trainloader))
如果plt.imshow(images[0])
运行正常,它应该可以让我简单地看到图像。
答案 0 :(得分:3)
transforms.RandomHorizontalFlip()
适用于PIL.Images
,而不适用于torch.Tensor
。在上面的代码中,您要在transforms.ToTensor()
之前应用transforms.RandomHorizontalFlip()
,这会导致张量。
但是,根据官方pytorch文档here
transforms.RandomHorizontalFlip()水平翻转给定的PIL 以给定的概率随机拍摄图像。
因此,只需在上面的代码中更改转换的顺序,如下所示:
train_transforms = transforms.Compose([transforms.Resize(255),
transforms.CenterCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
答案 1 :(得分:2)
只需添加transforms.ToPILImage()
即可转换为pil映像,然后它将起作用,例如:
transform = transforms.Compose([
transforms.ToPILImage(),
transforms.Resize(255),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.RandomHorizontalFlip(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])