Why isn't my Flask API returning the class label?

Posted: 2020-12-18 19:20:34

Tags: python machine-learning flask deep-learning

I have a Flask API that receives an image and should return a prediction of its class, using a pretrained model and the ImageNet class index.

I know my request script is reaching the API's /predict endpoint, because I get this output on the API side:

127.0.0.1 - - [18/Dec/2020 19:15:08] "POST /predict HTTP/1.1" 200 -

When I hard-code it like this, I do get a prediction, but I'm not sure how to turn it into an API:

imagenet_class_index = json.load(open('./static/imagenet_class_index.json'))

def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]

with open("img059.jpg", 'rb') as f:
    image_bytes = f.read()
    print(get_prediction(image_bytes = image_bytes))
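The hard-coded version works because imagenet_class_index maps each class index (stored as a string key) to a two-element list, which is also why the API can unpack the result into class_id and class_name. Below is a minimal pure-Python sketch of that lookup, assuming the JSON file has the shape {"<index>": ["<WordNet id>", "<label>"]}. SAMPLE_INDEX_JSON and lookup_class are hypothetical names, and a plain argmax over a list stands in for outputs.max(1):

```python
import json

# Hypothetical stand-in for ./static/imagenet_class_index.json.
# Assumption: same shape as the real file, {"<index>": ["<wnid>", "<label>"]}.
SAMPLE_INDEX_JSON = (
    '{"0": ["n01440764", "tench"],'
    ' "1": ["n01443537", "goldfish"],'
    ' "2": ["n01484850", "great_white_shark"]}'
)


def lookup_class(class_index, scores):
    """Return [wnid, label] for the highest score (hypothetical helper)."""
    # Argmax over a plain list, standing in for the tensor's outputs.max(1)
    predicted_idx = max(range(len(scores)), key=lambda i: scores[i])
    # JSON object keys are always strings, hence str(...) as in the question
    return class_index[str(predicted_idx)]


class_index = json.loads(SAMPLE_INDEX_JSON)
# The value is a two-element list, so it unpacks like in the /predict route
class_id, class_name = lookup_class(class_index, [0.1, 2.5, 0.3])
print(class_id, class_name)  # n01443537 goldfish
```

Indexing with the integer 1 instead of the string "1" would raise a KeyError, which is why the original code converts y_hat.item() with str() before the lookup.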

Here is a stripped-down version of my API:

import io
import json

from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request



app = Flask(__name__)
imagenet_class_index = json.load(open('./static/imagenet_class_index.json'))
model = models.densenet121(pretrained=True)
model.eval()


def transform_image(image_bytes):
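    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),  # standard ImageNet input size
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    # convert('RGB') so grayscale/RGBA uploads also yield the 3 channels Normalize expects
    image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    return my_transforms(image).unsqueeze(0)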

def get_prediction(image_bytes):
    tensor = transform_image(image_bytes = image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]


@app.route('/predict', methods=['POST'])
def predict():
    if request.method == 'POST':
        # we will get the file from the request 
        file = request.files['file']
        # convert file to bytes
        img_bytes = file.read()
        class_id, class_name = get_prediction(image_bytes = img_bytes)
        return jsonify({'class_id' : class_id, 'class_name' : class_name})

@app.route('/')
def base_route():
    return 'Greetings, Traveller!'

if __name__ == '__main__':
    app.run()

Edit: log for the base route:

127.0.0.1 - - [18/Dec/2020 19:14:59] "GET / HTTP/1.1" 200 -

request.py

import requests

resp = requests.post("http://localhost:5000/predict",
                     files={"file": open('img059.jpg','rb')})

1 answer:

Answer 0 (score: 1)

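I believe you are simply not printing the response. The server side is working (the log line for POST /predict ends in status 200, meaning the request succeeded), but the client never reads or prints the JSON it gets back. Your client script should be: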

request.py

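import requests

# Send the image as multipart form data under the field name "file",
# which is what request.files['file'] expects on the server
with open('img059.jpg', 'rb') as f:
    resp = requests.post("http://localhost:5000/predict", files={"file": f})

resp.raise_for_status()  # fail loudly on a non-2xx status
result = resp.json()     # parse the JSON body produced by jsonify()
print("Class Id: {}, Class Name: {}".format(result["class_id"], result["class_name"]))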

You should now see the result.
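If the client should also cope with a failed request (for example Flask returning an HTML error page instead of JSON), the printing step can be made a little more defensive. Below is a minimal sketch using only the standard library. format_prediction is a hypothetical helper, and the sample body is illustrative, not output from the asker's image:

```python
import json


def format_prediction(body):
    """Turn the /predict response body into a readable line (hypothetical helper)."""
    try:
        result = json.loads(body)
    except json.JSONDecodeError:
        # e.g. an HTML error page instead of the expected JSON
        return "Server did not return JSON: {!r}".format(body[:80])
    # Guard against a JSON body that lacks the keys the API is meant to send
    if not isinstance(result, dict) or "class_id" not in result or "class_name" not in result:
        return "Unexpected response: {}".format(result)
    return "Class Id: {}, Class Name: {}".format(result["class_id"], result["class_name"])


# Illustrative body shaped like jsonify({'class_id': ..., 'class_name': ...})
print(format_prediction('{"class_id": "n01443537", "class_name": "goldfish"}'))
```

In the client script it would be called as print(format_prediction(resp.text)) right after the requests.post call, since resp.text holds the raw response body.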