I wrote an image classification model using PyTorch's pretrained vgg16 model.
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
import urllib.request
from skimage.transform import resize
from skimage import io
import yaml
# Downloading imagenet 1000 classes list
file = urllib.request.urlopen("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt")
classes = ''
for f in file:
    classes = classes + f.decode("utf-8")
# safe_load avoids the Loader argument that newer PyYAML versions require for yaml.load
classes = yaml.safe_load(classes)
# Downloading pretrained vgg16 model
model = torch.hub.load('pytorch/vision:v0.6.0', 'vgg16', pretrained=True)
print(model)
for param in model.parameters():
    param.requires_grad = False
url, filename = ("https://raw.githubusercontent.com/pytorch/hub/master/dog.jpg", "dog.jpg")
image=io.imread(url)
plt.imshow(image)
plt.show()
# resize to 224x224x3
img = resize(image,(224,224,3))
plt.imshow(img)
plt.show()
# Normalizing input for vgg16
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
img1 = mean*img+std
img1 = np.clip(img1,0,1)
img1 = torch.from_numpy(img1).unsqueeze(0)
img1 = img1.permute(0,3,2,1) # batch_size x channels x height x width
model.eval()
pred = model(img1.float())
print(classes[torch.argmax(pred).numpy().tolist()])
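The code runs without errors, but it outputs the wrong class. I'm not sure where I went wrong, but if I had to guess, it would be either the ImageNet YAML class list or the normalization of the input image. Can anyone tell me where I am making a mistake?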
Answer 0 (score: 1)
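There are a few problems with the image preprocessing. First, normalization is computed as (value - mean) / std, not value * mean + std. Second, the values should not be clipped to [0, 1]; normalization deliberately moves the values away from [0, 1]. Third, the image as a NumPy array has shape [height, width, 3], and your permutation swaps the height and width dimensions, producing a tensor of shape [batch_size, channels, width, height].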
img = resize(image,(224,224,3))
# Normalizing input for vgg16
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
img1 = (img - mean) / std  # subtract the per-channel mean, then divide by the per-channel std
img1 = torch.from_numpy(img1).unsqueeze(0)
img1 = img1.permute(0, 3, 1, 2) # batch_size x channels x height x width
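The two points about normalization and axis order can be checked on a small dummy array. The following is only a sketch using NumPy: the 4 x 6 array is a hypothetical stand-in for the resized image, chosen to be non-square so that the swapped axes are visible in the shape. With a square 224 x 224 input both permutations yield the same shape, which is likely why the original code raised no error.

```python
import numpy as np

# Hypothetical stand-in for a resized image: height 4, width 6, 3 channels, values in [0, 1).
rng = np.random.default_rng(0)
img = rng.random((4, 6, 3))

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# Broadcasting applies mean and std per channel along the last axis.
normalized = (img - mean) / std

batch = normalized[np.newaxis, ...]      # (1, H, W, C)
correct = batch.transpose(0, 3, 1, 2)    # (1, C, H, W)
swapped = batch.transpose(0, 3, 2, 1)    # (1, C, W, H): height and width exchanged

print(correct.shape)  # (1, 3, 4, 6)
print(swapped.shape)  # (1, 3, 6, 4)
# The normalized values are not confined to [0, 1], so clipping would discard information.
print(normalized.min(), normalized.max())
```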
Instead of doing this manually, you can use torchvision.transforms.
from torchvision import transforms
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
img = resize(image,(224,224,3))
img1 = preprocess(img)
img1 = img1.unsqueeze(0)
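If you load the image with PIL, you can also resize it by adding transforms.Resize((224, 224)) to the preprocessing pipeline, or you can even add transforms.ToPILImage() to convert the image to a PIL image first (in the torchvision version used here, transforms.Resize requires a PIL image).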