我已经训练了一个模型(通过Keras框架),并用model.save('model.hdf5')
导出了模型,现在我想将其与很棒的Streamlit集成。
显然,我不想每次最终用户插入新输入时都加载模型,而是一劳永逸地加载它。
所以我的代码看起来像这样:
@st.cache
def load_my_model():
model = load_model('model.hdf5')
model.summary()
return model
if __name__ == '__main__':
st.title('My first app')
sentence = st.text_input('Input your sentence here:')
model = load_my_model()
if sentence:
y_hat = model.predict(sentence)
这样,我得到了:
“ streamlit.errors.UnhashableType:”
例外。
我尝试使用@st.cache(allow_output_mutation=True)
并在精简页面上运行查询时。我得到了:
“ TypeError:无法将feed_dict键解释为张量:Tensor Tensor(” input_1:0“,shape =(?, 80),dtype = int32)不是该图的元素。”
(当然,没有任何缓存装饰器的模型将被加载并正常工作)
我应该如何正确加载并缓存 Keras训练的模型?
非常感谢!
答案 0 :(得分:0)
解决方案是:
_make_predict_function()
通话from keras import backend as K
@st.cache(allow_output_mutation=True)
def load_model():
model = load_model(MODEL_PATH)
model._make_predict_function()
model.summary() # included to make it visible when model is reloaded
session = K.get_session()
return model, session
if __name__ == '__main__':
st.title('My first app')
sentence = st.text_input('Input your sentence here:')
model, session = load_model()
if sentence:
K.set_session(session)
y_hat = model.predict(sentence)