获取Torchvision预训练网络的分类标签

时间:2020-03-05 01:23:43

标签: image-processing deep-learning classification pytorch torchvision

Pytorch的{​​{1}}软件包提供pre-trained neural networks用于图像分类。我一直在使用以下代码使用Alexnet对图像进行分类(注意:其中一些代码来自this webpage):

torchvision

总共有1,000个类别,而from PIL import Image import torch from torchvision import transforms from torchvision import models # function to transform image transform = transforms.Compose([ transforms.Resize(224), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) # image img = Image.open('/path/to/image.jpg') img = transform(img) img = torch.unsqueeze(img, 0) # alexnet alexnet = models.alexnet(pretrained=True) alexnet.eval() out = alexnet(img) percents = torch.nn.functional.softmax(out, dim=1)[0] * 100 top5_vals, top5_inds = percents.topk(5) 变量为我提供了前5个类别的索引。但是如何获得相关的标签(例如蜗牛,篮球,香蕉)?我似乎找不到任何种类的列表作为Pytorch文档或top5_inds变量的一部分。

1 个答案:

答案 0 :(得分:2)

Torchvision模型在ImageNet数据集上进行了预训练。由于其全面性和规模,ImageNet是用于预培训和转移学习的最常用数据集。如您所述,它有1000个类。可以搜索完整的课程列表,或者您可以在GitHub上参考以下列表:https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a