Streamlit缓存Keras训练模型

时间:2020-04-04 15:29:57

标签: keras streamlit

我已经训练了一个模型(通过Keras框架),并用model.save('model.hdf5')导出了模型,现在我想将其与很棒的Streamlit集成。 显然,我不想每次最终用户插入新输入时都加载模型,而是一劳永逸地加载它。 所以我的代码看起来像这样:

@st.cache
def load_my_model():
    model = load_model('model.hdf5')
    model.summary()

    return model

if __name__ == '__main__':
    st.title('My first app')
    sentence = st.text_input('Input your sentence here:')
    model = load_my_model()
    if sentence:
        y_hat = model.predict(sentence)

这样,我得到了:

“ streamlit.errors.UnhashableType:”

例外。 我尝试使用@st.cache(allow_output_mutation=True)并在精简页面上运行查询时。我得到了:

“ TypeError:无法将feed_dict键解释为张量:Tensor Tensor(” input_1:0“,shape =(?, 80),dtype = int32)不是该图的元素。”

(当然,没有任何缓存装饰器的模型将被加载并正常工作)

我应该如何正确加载并缓存 Keras训练的模型?

  • Python版本:2.7(不幸的是)
  • Keras版本:2.1.3
  • Tensorflow版本:1.3.0
  • Streamlit版本:0.55.2

非常感谢!

1 个答案:

答案 0 :(得分:0)

解决方案是:

  1. 添加_make_predict_function()通话
  2. 返回会话
from keras import backend as K

@st.cache(allow_output_mutation=True)
def load_model():
    model = load_model(MODEL_PATH)
    model._make_predict_function()
    model.summary()  # included to make it visible when model is reloaded
    session = K.get_session()
    return model, session

if __name__ == '__main__':
    st.title('My first app')
    sentence = st.text_input('Input your sentence here:')
    model, session = load_model()
    if sentence:
        K.set_session(session)
        y_hat = model.predict(sentence)