Classifying images with the pretrained PyTorch VGG16 model and mapping predictions to classes

Date: 2020-06-20 06:15:18

Tags: pytorch classification torch vgg-net torchvision

I wrote an image classifier using PyTorch's pretrained VGG16 model.

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
import urllib.request
from skimage.transform import resize
from skimage import io
import yaml

# Downloading imagenet 1000 classes list
file = urllib.request.urlopen("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt")
classes = ''
for f in file:
  classes = classes +  f.decode("utf-8")
classes = yaml.safe_load(classes)  # dict: class index -> label string

# Downloading pretrained vgg16 model
model = torch.hub.load('pytorch/vision:v0.6.0', 'vgg16', pretrained=True)

print(model)

for param in model.parameters():
    param.requires_grad = False


url, filename = ("https://raw.githubusercontent.com/pytorch/hub/master/dog.jpg", "dog.jpg")

image=io.imread(url)

plt.imshow(image)
plt.show()

# resize to 224x224x3
img = resize(image,(224,224,3))

plt.imshow(img)
plt.show()
# Normalizing input for vgg16
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
img1 = mean*img+std
img1 = np.clip(img1,0,1)

img1 = torch.from_numpy(img1).unsqueeze(0)
img1 = img1.permute(0,3,2,1) # batch_size x channels x height x width

model.eval()
pred = model(img1.float())
print(classes[torch.argmax(pred).numpy().tolist()])

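The code runs without errors, but it predicts the wrong class. I'm not sure where I went wrong; if I had to guess, it is either the ImageNet YAML class list or the normalization of the input image. Can anyone tell me where my mistake is?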

1 answer:

Answer 0 (score: 1)

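There are several problems with the image preprocessing. First, normalization is computed as (value - mean) / std, not value * mean + std. Second, the values should not be clipped to [0, 1]; normalization deliberately moves values outside [0, 1]. Third, as a NumPy array the image has shape [height, width, 3], and your permutation swaps the height and width dimensions, producing a tensor of shape [batch_size, channels, width, height].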
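The first two points can be checked numerically. The sketch below uses NumPy only; the 2x2 test image and the helper names `normalize_hwc` and `wrong_normalize_hwc` are made up for illustration. It compares the per-channel formula (value - mean) / std with the question's mean * value + std followed by clipping.

```python
import numpy as np

# ImageNet channel statistics used by torchvision's pretrained models
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])

def normalize_hwc(img):
    """Hypothetical helper: normalize an HxWx3 float image in [0, 1] per channel."""
    return (img - MEAN) / STD  # broadcasts over the last (channel) axis

def wrong_normalize_hwc(img):
    """Hypothetical helper reproducing the question's formula plus clipping."""
    return np.clip(MEAN * img + STD, 0, 1)

# Tiny made-up "image": one black pixel, one white pixel, two mid-gray pixels
img = np.array([[[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]],
                [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]]])

good = normalize_hwc(img)
bad = wrong_normalize_hwc(img)
print(good.min(), good.max())  # spans roughly -2 to +2.6, well outside [0, 1]
print(bad.min(), bad.max())    # squeezed into a narrow band inside [0, 1]
```

Clipping `good` to [0, 1] would map every black pixel to 0 in all channels, discarding exactly the range the pretrained weights were trained on.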

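img = resize(image, (224, 224, 3))  # float64 array in [0, 1], shape [224, 224, 3]


# Normalizing input for vgg16: (value - mean) / std per channel
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
img1 = (img - mean) / std  # no clipping: values are meant to leave [0, 1]

img1 = torch.from_numpy(img1).unsqueeze(0)
img1 = img1.permute(0, 3, 1, 2)  # batch_size x channels x height x width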
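The permutation bug is easy to miss because the question's input is square: for a 224x224 image, permute(0, 3, 2, 1) still yields [1, 3, 224, 224], so no error is raised and the model simply sees a transposed picture. A sketch with a made-up non-square random image makes the difference visible in both the shape and the contents:

```python
import numpy as np
import torch

# Non-square HxWx3 stand-in image so a height/width swap shows up in the shape
h, w = 4, 6
img = np.random.rand(h, w, 3)

x = torch.from_numpy(img).unsqueeze(0)  # [1, H, W, 3]
wrong = x.permute(0, 3, 2, 1)           # [1, 3, W, H]: height and width swapped
right = x.permute(0, 3, 1, 2)           # [1, 3, H, W]: layout the model expects

print(tuple(wrong.shape))  # (1, 3, 6, 4)
print(tuple(right.shape))  # (1, 3, 4, 6)

# Channel 0 of `right` is exactly img[:, :, 0]; `wrong` holds its transpose
assert torch.equal(right[0, 0], torch.from_numpy(img[:, :, 0]))
assert torch.equal(wrong[0, 0], torch.from_numpy(img[:, :, 0]).T)
```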

Instead of doing this manually, you can use torchvision.transforms.

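from skimage.transform import resize
from torchvision import transforms

preprocess = transforms.Compose([
    transforms.ToTensor(),  # HxWxC ndarray -> CxHxW tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

img = resize(image, (224, 224, 3))
img1 = preprocess(img)
img1 = img1.unsqueeze(0).float()  # add batch dimension; float64 -> float32 for the model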

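If you load the image with PIL, you can also resize it by adding transforms.Resize((224, 224)) to the preprocessing pipeline. You can even add transforms.ToPILImage() at the start to convert the image to a PIL image first (in torchvision 0.6, which the question loads, transforms.Resize requires a PIL image; newer versions also accept tensors).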