Restore keras model with embedding layer

Time: 2020-04-14 18:38:09

Tags: tensorflow keras

Can anyone help me with restoring the Keras model only once? Here is my code:

import json

import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
from flask import Flask, request
from flask_cors import cross_origin
from tensorflow.keras import layers, backend as K
from tensorflow.keras.models import Model

app = Flask(__name__)
url = "..."  # TF Hub module URL for the sentence encoder


def universal_embedding(x):
  # wrap the TF Hub module so it can be used inside a Lambda layer
  embed = hub.Module(url)
  return embed(tf.squeeze(tf.cast(x, tf.string), axis=[1]), signature="default", as_dict=True)["default"]


def create_model():
  input_text = layers.Input(shape=(1,), dtype=tf.string)
  embedding = layers.Lambda(universal_embedding, output_shape=(512,))(input_text)
  dense = layers.Dense(256, activation='relu')(embedding)
  pred = layers.Dense(4, activation='softmax')(dense)

  model = Model(inputs=[input_text], outputs=pred)
  model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
  model.summary()

  return model

@app.route("/predictLabel", methods=['GET', 'POST'])
@cross_origin()
def predict_label():
  model = create_model()

  request_data = json.loads(request.data)
  text = request_data['text']

  request_list = [text]

  request_list = np.array(request_list, dtype=object)[:, np.newaxis]

  with tf.compat.v1.Session() as session:
     K.set_session(session)
     session.run([tf.compat.v1.global_variables_initializer(), tf.compat.v1.tables_initializer()])
     model.load_weights('./model.h5')
     predicts = model.predict(request_list, batch_size=32)

Right now the model is loaded on every request. How can I load it just once at startup? I know we need to handle the TensorFlow session and graph to do this, but I'm not sure how to implement it; a rough sketch of what I'm aiming for is below. Thanks a lot.
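Here is that sketch: build the graph, session and model once at module load, then reuse them inside the request handler. This is untested and just my guess at the pattern; it assumes TF 1.x-style sessions and reuses create_model, url and model.h5 from above, and the module-level graph/session names are just my own.

# Rough, untested sketch: create the graph, session and model once at import time.
graph = tf.compat.v1.get_default_graph()
session = tf.compat.v1.Session(graph=graph)

with graph.as_default(), session.as_default():
  K.set_session(session)
  model = create_model()
  session.run([tf.compat.v1.global_variables_initializer(),
               tf.compat.v1.tables_initializer()])
  model.load_weights('./model.h5')


@app.route("/predictLabel", methods=['GET', 'POST'])
@cross_origin()
def predict_label():
  request_data = json.loads(request.data)
  text = request_data['text']
  request_list = np.array([text], dtype=object)[:, np.newaxis]

  # reuse the graph and session created at startup instead of building a new
  # model and session for every request
  with graph.as_default(), session.as_default():
    predicts = model.predict(request_list, batch_size=32)

Is this the right way to keep the session and graph alive across Flask requests, or is there a better pattern?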

0 Answers:

No answers yet.