在 Flask 中保持模型的初始化状态

时间:2019-09-24 01:06:24

标签: python flask keras

我有一个问题,感觉应该很容易解决,但我不知道该怎么做;我对 Flask 还很陌生。

我想做一个 Flask 应用程序,允许用户上传图像,并用一个专门训练的 Keras 模型对图像进行测试,以识别猫的品种。

这段代码在不使用 Flask 的情况下已经可以执行预测;只要在运行预测之前先初始化模型,就能正常运行。

我的问题是:如何让模型保持已初始化的状态,以便可以多次运行预测,而不必每次都重新初始化?

代码如下:

import os
#import magic
import urllib.request
from app import app
from flask import Flask, flash, request, redirect, render_template
from werkzeug.utils import secure_filename
import tensorflow as tf
import numpy as np
import requests
from tensorflow.keras.preprocessing import image
from tensorflow.keras.preprocessing.image import load_img
from tensorflow.keras.models import load_model
from tensorflow.keras.applications import xception
from PIL import Image

# File extensions accepted for upload (allowed_file lower-cases the extension
# before checking membership).
# NOTE(review): 'txt' and 'pdf' pass this check, but the classifier loads the
# upload with image.load_img, so those types cannot actually be classified —
# consider restricting this to image formats.
ALLOWED_EXTENSIONS = {'txt', 'pdf', 'png', 'jpg', 'jpeg', 'gif'}

class CatClass:
    """Cat-breed classifier backed by a Keras model copied from storage.

    The model is fetched and loaded once in ``__init__``; every later call to
    ``catEstimator`` reuses ``self.catModel`` instead of reloading it.
    """

    # Breed names in the order the model learned them: the model's output
    # index ``i`` corresponds to ``CAT_NAMES[i]``. Defined once on the class
    # rather than rebuilt on every prediction.
    CAT_NAMES = (
        "Bengal", "Abyssinian", "BritishShorthair", "Birman", "Sphynx",
        "Bombay", "EgyptianMau", "Persian", "Ragdoll", "MaineCoon",
        "Siamese", "RussianBlue", "AmericanBobtail", "DevonRex", "AmericanCurl",
        "DonSphynx", "Manx", "Balinese", "Burmilla", "Burmese",
        "KhaoManee", "Chausie", "AmericanShortHair", "Chartreux", "Pixiebob",
        "JapaneseBobtail", "BritishLonghair", "CornishRex", "Tabby", "Somali",
        "ExoticShortHair", "Tonkinese", "OrientalShortHair", "Minskin", "Korat",
        "Savannah", "Havana", "Singapura", "Nebelung", "OrientalLonghair",
        "TurkishAngora", "ScottishFold", "KurilianBobtail", "Lykoi",
        "ScottishFoldLonghair",
        "Ocicat", "Munchkin", "SelkirkRex", "AustralianMist", "AmericanWireHair",
        "TurkishVan", "SnowShoe", "Peterbald", "Siberian", "Toybob",
        "Himalayan", "LePerm", "NorwegianForestCat",
    )

    def __init__(self, model_uri='gs://modelstorageforcats/94Kitty.h5',
                 local_path='./94Kitty.h5'):
        """Copy the ``.h5`` model to local disk and load it with Keras.

        Args:
            model_uri: Source of the model file, opened through TensorFlow's
                ``file_io`` (so ``gs://`` paths work as before).
            local_path: Where the copied model is written before loading.

        Raises:
            RuntimeError: if the model file cannot be read or written.
        """
        from tensorflow.python.lib.io import file_io

        try:
            # Context managers close both handles even if the copy fails
            # part-way; the previous version leaked them on error.
            with file_io.FileIO(model_uri, mode='rb') as model_file, \
                    open(local_path, 'wb') as temp_model_file:
                temp_model_file.write(model_file.read())
        except (OSError, tf.errors.OpError) as err:
            # ``raise("...")`` raises a str, which is itself a TypeError in
            # Python 3; wrap the real cause in a proper exception instead.
            raise RuntimeError("Issues with getting the model") from err

        # NOTE(review): under TF 1.x with a threaded Flask server, calling
        # predict() from request threads may require capturing the default
        # graph here — confirm against the TensorFlow version in use.
        self.catModel = load_model(local_path)

    def catEstimator(self, catImage, upload_dir=None):
        """Predict the breed of the cat in an uploaded image.

        Args:
            catImage: File name of the image inside ``upload_dir``.
            upload_dir: Directory containing the image. Defaults to the
                ``uploads`` directory next to this script (the original
                behavior); pass the directory the upload was actually saved
                to if it differs.

        Returns:
            The entry of ``CAT_NAMES`` with the highest predicted score.

        Raises:
            RuntimeError: if the image cannot be loaded or preprocessed.
        """
        if upload_dir is None:
            upload_dir = os.path.join(os.path.dirname(__file__), "uploads")
        image_path = os.path.join(upload_dir, catImage)

        # Resize to the 299x299 input used for Xception and add a batch axis.
        try:
            kitty_pic = image.load_img(image_path, target_size=(299, 299))
            # img_to_array returns a floating-point array, so the scaling in
            # preprocess_input is done on floats rather than raw PIL pixels.
            pixels = image.img_to_array(kitty_pic)
            x = xception.preprocess_input(np.expand_dims(pixels, axis=0))
        except (OSError, ValueError) as err:
            raise RuntimeError("Error with the images") from err

        prediction = self.catModel.predict(x)

        # Batch of one: take the scores for the single image and pick the
        # best class (avoids int() on a 1-element array, deprecated in NumPy).
        label = int(np.argmax(prediction[0]))
        return self.CAT_NAMES[label]

# Load the classifier (and its model) once, when this module is imported. The
# upload handler reuses this instance, so the model is not re-initialised for
# each prediction.
catter = CatClass()


def allowed_file(filename):
    """Tell whether *filename* carries an extension from ALLOWED_EXTENSIONS.

    Only the text after the final dot is considered, and it is compared in
    lower case. Names without any dot are rejected.
    """
    _, dot, extension = filename.rpartition('.')
    if not dot:
        return False
    return extension.lower() in ALLOWED_EXTENSIONS

@app.route('/')
def upload_form():
    """Render the page containing the image upload form."""
    template_name = 'upload.html'
    return render_template(template_name)

@app.route('/', methods=['POST'])
def upload_file():
    """Save an uploaded image, classify it, and report the result via flash.

    Every outcome redirects back to the form; user feedback is delivered
    through ``flash`` messages.
    """
    # This route only accepts POST, so the former ``request.method == 'POST'``
    # check was always true and has been removed.
    if 'file' not in request.files:
        flash('No file part')
        return redirect(request.url)

    file = request.files['file']
    if file.filename == '':
        flash('No file selected for uploading')
        return redirect(request.url)

    if not (file and allowed_file(file.filename)):
        flash('Allowed file types are txt, pdf, png, jpg, jpeg, gif')
        return redirect(request.url)

    filename = secure_filename(file.filename)
    if not filename:
        # secure_filename may return an empty string; joining '' onto the
        # upload folder would point the save at the folder itself.
        flash('Invalid file name')
        return redirect(request.url)

    file.save(os.path.join(app.config['UPLOAD_FOLDER'], filename))
    flash('File successfully uploaded')

    # NOTE(review): the file is saved under app.config['UPLOAD_FOLDER'], while
    # catter.catEstimator reads from an "uploads" directory next to the
    # script — confirm both refer to the same location.
    try:
        breed = catter.catEstimator(filename)
    except Exception:  # request boundary: log and report instead of a 500
        app.logger.exception('Prediction failed for %s', filename)
        flash('Could not classify the uploaded image')
    else:
        flash(breed)
    return redirect('/')

# Run the Flask application only when this file is executed directly, not
# when it is imported as a module.
if __name__ == "__main__":
    app.run()

0 个答案:

没有答案