使用ResNet50加载“ Imagenet”的权重时,每次加载权重时几乎都要花费10-11秒。 有什么方法可以减少加载时间?
代码:
from flask import Flask, render_template, request
from werkzeug import secure_filename
from flask import request,Flask
import json
import os
import time
from keras.preprocessing import image as image_util
from keras.applications.imagenet_utils import preprocess_input
from keras.applications.imagenet_utils import decode_predictions
# from keras.applications import ResNet50
from keras.applications.inception_v3 import InceptionV3
import numpy as np
app = Flask(__name__)
@app.route('/object_rec', methods=['POST'])
def object_rec():
f = request.files['file']
file_path = ("./upload/"+secure_filename(f.filename))
f.save(file_path)
image = image_util.load_img(file_path,target_size=(299,299))
image = image_util.img_to_array(image)
image = np.expand_dims(image,axis=0) #(224,224,3) --> (1,224,224,3)
image = preprocess_input(image)
start_time = time.time()
model = InceptionV3(weights="imagenet")
pred = model.predict(image)
p = decode_predictions(pred)
ans = p[0][0]
acc = ans[2]
acc = str(acc)
if ans[1] == "Granny_Smith":
ans = ans[1]
ans = 'Apple'
else:
ans = ans[1]
print("THE PREDICTED IMAGE IS: "+ans)
print("THE ACCURACY IS: "+acc)
print("--- %s seconds ---" % (time.time() - start_time))
result = {
"status": True,
"object": ans,
"score":acc
}
result = json.dumps(result)
return result
if __name__ == '__main__':
app.run(host='0.0.0.0',port=6000,debug=True)
花费的时间在8-11秒之间。 如果它在3-4秒内加载模型并进行分类,我会很好。
预先感谢
答案 0 :(得分:0)
您可以通过在特定的会话中加载模型,然后在每次使用该模型时设置该特定的会话,然后仅在需要的位置调用预测:
app = Flask(__name__)
sess = tf.Session(config=tf_config)
graph = tf.get_default_graph()
# IMPORTANT: models have to be loaded AFTER SETTING THE SESSION for keras!
# Otherwise, their weights will be unavailable in the threads after the
session there has been set
set_session(sess)
model = InceptionV3(weights="imagenet")
@app.route('/object_rec', methods=['POST'])
def object_rec():
global sess
global graph
with graph.as_default():
set_session(sess)
model.predict(...)
if __name__ == '__main__':
app.run(host='0.0.0.0',port=6000,debug=True)