加载keras模型并将其缓存在变量中,而无需重新加载

时间:2019-03-27 13:41:54

标签: python tensorflow keras

在我的Flask应用程序开始时加载模型,然后将其用于端点中的预测会导致错误

'ValueError:Tensor Tensor(“ dense / Softmax:0”,shape =(?, 4),dtype = float32)不是此图的元素。'

model = keras.models.load_model("model.h5")

@app.route("/predict", methods=["POST"])
def predict():
    json_data = request.get_json()

    variable = preparePredictionInput(
        [variable], alphabetDict, maxVariableLength)
    prediction = list(model.predict(variable, steps=1, verbose=1)[0])

但是每次调用预测端点时都加载keras模型似乎工作正常

@app.route("/predict", methods=["POST"])
def predict():
    json_data = request.get_json()
    model = keras.models.load_model("model.h5")

    variable = preparePredictionInput(
        [variable], alphabetDict, maxVariableLength)
    prediction = list(model.predict(variable, steps=1, verbose=1)[0]) 

有没有办法解决这个问题?这从根本上降低了每次都要重新加载模型的性能。

1 个答案:

答案 0 :(得分:0)

像模型变量这样的符号不是全局的。看看下面的代码:

def init():
  global model
  model = lkeras.models.load_model("model.h5")

@app.route("/predict", methods=["POST"])
def predict():
    json_data = request.get_json()
    variable = preparePredictionInput([variable], alphabetDict, maxVariableLength)
    prediction = list(model.predict(variable, steps=1, verbose=1)[0])


if __name__ == "__main__":
    init()
    app.run()