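How can I reduce the loading time of a pretrained model?

Asked: 2020-02-12 11:01:41

Tags: python keras deep-learning resnet transfer-learning

When I use ResNet50 with the "imagenet" weights, loading the weights takes almost 10-11 seconds every time. Is there any way to reduce the loading time?

Code: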

from flask import Flask, render_template, request
from werkzeug.utils import secure_filename
import json
import os
import time

from keras.preprocessing import image as image_util 
from keras.applications.inception_v3 import preprocess_input  # InceptionV3 expects inputs scaled to [-1, 1]
from keras.applications.imagenet_utils import decode_predictions
# from keras.applications import ResNet50
from keras.applications.inception_v3 import InceptionV3
import numpy as np

app = Flask(__name__)

@app.route('/object_rec', methods=['POST'])
def object_rec():

      f = request.files['file']
      file_path = ("./upload/"+secure_filename(f.filename))
      f.save(file_path)
      image = image_util.load_img(file_path,target_size=(299,299))
      image = image_util.img_to_array(image)
      image = np.expand_dims(image,axis=0) # (299,299,3) --> (1,299,299,3)
      image = preprocess_input(image)

      start_time = time.time()
      model = InceptionV3(weights="imagenet")
      pred = model.predict(image)
      p = decode_predictions(pred)

      ans = p[0][0]
      acc = ans[2]
      acc = str(acc)
      if ans[1] == "Granny_Smith":
            ans = ans[1]
            ans = 'Apple'
      else:
            ans = ans[1]
      print("THE PREDICTED IMAGE IS: "+ans)
      print("THE ACCURACY IS: "+acc)
      print("--- %s seconds ---" % (time.time() - start_time))
      result = {
            "status": True,
            "object": ans,
            "score":acc
      }
      result = json.dumps(result)
      return result

if __name__ == '__main__':
   app.run(host='0.0.0.0',port=6000,debug=True)

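Each request takes between 8 and 11 seconds. I would be happy if it could load the model and classify the image in 3-4 seconds.

Thanks in advance.

1 Answer:

Answer 0: (score: 0)

You can load the model once in a dedicated session, set that session each time the model is used, and then call predict only where it is needed: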

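import tensorflow as tf
from flask import Flask
from keras.applications.inception_v3 import InceptionV3
from keras.backend import set_session

app = Flask(__name__)
# tf_config was not defined in the original answer; a default TF 1.x ConfigProto is assumed here.
tf_config = tf.ConfigProto()
sess = tf.Session(config=tf_config)
graph = tf.get_default_graph()

# IMPORTANT: models have to be loaded AFTER SETTING THE SESSION for keras!
# Otherwise, their weights will be unavailable in the threads after the
# session there has been set.
set_session(sess)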

model = InceptionV3(weights="imagenet")

@app.route('/object_rec', methods=['POST'])
def object_rec():
   global sess
   global graph
   with graph.as_default():
      set_session(sess)
      model.predict(...)

if __name__ == '__main__':
   app.run(host='0.0.0.0',port=6000,debug=True)
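
The core idea above, independent of the TensorFlow session details, is to pay the loading cost once rather than on every request. The following is a minimal, framework-agnostic sketch of that pattern, not the answer's exact method: `LazyModel`, `FakeModel`, `slow_loader`, and `handle_request` are hypothetical names, and the `time.sleep` call only simulates an expensive load such as `InceptionV3(weights="imagenet")`.

```python
import threading
import time


class FakeModel:
    """Hypothetical stand-in for a real Keras model."""

    def predict(self, x):
        return [v * 2 for v in x]


def slow_loader():
    """Simulates an expensive load, e.g. InceptionV3(weights="imagenet")."""
    time.sleep(0.2)  # placeholder for the real loading cost
    return FakeModel()


class LazyModel:
    """Loads the model on first use and reuses it for every later call."""

    def __init__(self, loader):
        self._loader = loader
        self._model = None
        self._lock = threading.Lock()
        self.load_count = 0  # illustration only: number of actual loads

    def get(self):
        # Double-checked locking so that concurrent requests trigger one load.
        if self._model is None:
            with self._lock:
                if self._model is None:
                    self._model = self._loader()
                    self.load_count += 1
        return self._model


holder = LazyModel(slow_loader)


def handle_request(x):
    """Plays the role of the Flask view; it never constructs the model itself."""
    return holder.get().predict(x)


if __name__ == "__main__":
    for i in range(3):
        start = time.perf_counter()
        handle_request([1, 2, 3])
        print("request %d took %.3f s" % (i, time.perf_counter() - start))
```

In this sketch only the first request pays the simulated loading cost. In a real service you could also call `holder.get()` once before `app.run(...)` so that no request is slow.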