Keras模型无法通过函数调用预测

时间:2018-08-03 20:22:47

标签: python tensorflow keras

我正在像这样加载许多Keras模型:

from keras import backend as K # Tensorflow backend
from MiscFunctions import *
def main():
    for i in range(...):
        K.clear_session() # Needed to speed up model loading
        model = load_model(...)
        model._make_predict_function()
main()

但是,我稍后在脚本中有一个函数调用,该函数接受模型输入并从该模型输出预测。

length = get_length(model, ...)

这是get_length

的简化代码
def get_length(model, ...):
    ...
    # input_vector is the correct size
    return model.predict(np.asarray(input_vector).reshape(1,1,len(input_vector)))

除了prediction方法调用给我以外的错误:

tensorflow.python.framework.errors_impl.NotFoundError: FetchOutputs node dense_1/Softmax:0: not found
Exception tensorflow.python.framework.errors_impl.InvalidArgumentError: InvalidArgumentError() in <bound method _Callable.__del__ of <tensorflow.python.client.session._Callable object at 0x7f619b8c7e10>> ignored

我怀疑K.clear_session()行可能是造成此问题的原因,但是我需要清除会话以加快模型加载速度。我该如何解决这个问题?

1 个答案:

答案 0 :(得分:0)

为了有效地加载模型,请将其全局化并加载到另一个函数中,这样就不必一次又一次地加载它。全局设置后,可以在主要功能中对其进行访问:

def load_model():

    global model

    json_file = open('model.json', 'r')
    model_json = json_file.read()
    model = model_from_json(model_json)
    model.load_weights("model.h5")
    model._make_predict_function()