我编写了以下预测代码,这些代码根据训练有素的分类器模型进行预测。现在,预测时间大约是40秒,我希望尽可能减少。
我可以对推理脚本进行优化吗?还是应该在训练脚本中寻求发展?
import torch
import torch.nn as nn
from torchvision.models import resnet18
from torchvision.transforms import transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.autograd import Variable
import torch.functional as F
from PIL import Image
import os
import sys
import argparse
import time
import json
parser = argparse.ArgumentParser(description = 'To Predict from a trained model')
parser.add_argument('-i','--image', dest = 'image_name', required = True, help='Path to the image file')
args = parser.parse_args()
def predict_image(image_path):
print("prediciton in progress")
image = Image.open(image_path)
transformation = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
image_tensor = transformation(image).float()
image_tensor = image_tensor.unsqueeze_(0)
if cuda:
image_tensor.cuda()
input = Variable(image_tensor)
output = model(input)
index = output.data.numpy().argmax()
return index
def parameters():
hyp_param = open('param_predict.txt','r')
param = {}
for line in hyp_param:
l = line.strip('\n').split(':')
def class_mapping(index):
with open("class_mapping.json") as cm:
data = json.load(cm)
if index == -1:
return len(data)
else:
return data[str(index)]
def segregate():
with open("class_mapping.json") as cm:
data = json.load(cm)
try:
os.mkdir(seg_dir)
print("Directory " , seg_dir , " Created ")
except OSError:
print("Directory " , seg_dir , " already created")
for x in range (0,len(data)):
dir_path="./"+seg_dir+"/"+data[str(x)]
try:
os.mkdir(dir_path)
print("Directory " , dir_path , " Created ")
except OSError:
print("Directory " , dir_path , " already created")
path_to_model = "./models/"+'trained.model'
checkpoint = torch.load(path_to_model)
seg_dir="segregation_folder"
cuda = torch.cuda.is_available()
num_class = class_mapping(index=-1)
print num_class
model = resnet18(num_classes = num_class)
if cuda:
model.load_state_dict(checkpoint)
else:
model.load_state_dict(checkpoint, map_location = 'cpu')
model.eval()
if __name__ == "__main__":
imagepath = "./Predict_Image/"+args.image_name
since = time.time()
img = Image.open(imagepath)
prediction = predict_image(imagepath)
name = class_mapping(prediction)
print("Time taken = ",time.time()-since)
print("Predicted Class: ",name)
可以在以下位置找到整个项目 https://github.com/amrit-das/custom_image_classifier_pytorch/
答案 0 :(得分:2)
在没有分析器的输出的情况下,很难分辨出其中的多少是由于代码效率低下所致。话虽这么说,PyTorch有很多启动开销-换句话说,与单个图像的推理时间相比,初始化库,模型,负载权重并将其传输到GPU的速度很慢。作为用于单图像预测的CLI实用程序,这使得它非常差。
如果您的用例确实需要使用单个图像而不是批处理,则没有太大的优化潜力。我看到的两个选择是