我正在像这样加载许多Keras模型:
from keras import backend as K # Tensorflow backend
from MiscFunctions import *
def main():
for i in range(...):
K.clear_session() # Needed to speed up model loading
model = load_model(...)
model._make_predict_function()
main()
但是,我稍后在脚本中有一个函数调用,该函数接受模型输入并从该模型输出预测。
length = get_length(model, ...)
这是get_length
def get_length(model, ...):
...
# input_vector is the correct size
return model.predict(np.asarray(input_vector).reshape(1,1,len(input_vector)))
除了prediction
方法调用给我以外的错误:
tensorflow.python.framework.errors_impl.NotFoundError: FetchOutputs node dense_1/Softmax:0: not found
Exception tensorflow.python.framework.errors_impl.InvalidArgumentError: InvalidArgumentError() in <bound method _Callable.__del__ of <tensorflow.python.client.session._Callable object at 0x7f619b8c7e10>> ignored
我怀疑K.clear_session()
行可能是造成此问题的原因,但是我需要清除会话以加快模型加载速度。我该如何解决这个问题?
答案 0 :(得分:0)
为了有效地加载模型,请将其全局化并加载到另一个函数中,这样就不必一次又一次地加载它。全局设置后,可以在主要功能中对其进行访问:
def load_model():
global model
json_file = open('model.json', 'r')
model_json = json_file.read()
model = model_from_json(model_json)
model.load_weights("model.h5")
model._make_predict_function()