使用tf-hub模型时如何避免每次预测都进行初始化会话

时间:2019-11-15 22:07:27

标签: tensorflow tf-hub

我按照本文(https://colab.research.google.com/drive/1Odry08Jm0f_YALhAt4vp9qa5w8prUzDY#scrollTo=_stfC_7VFhS8)的方法使用通用语句编码器构建模型。基本上,培训过程是

module_url = "https://tfhub.dev/google/universal-sentence-encoder-large/3" 
embed = hub.Module(module_url)
def UniversalEmbedding(x):
    return embed(tf.squeeze(tf.cast(x, tf.string)), signature="default", as_dict=True)["default"]

input_text = layers.Input(shape=(1,), dtype=tf.string)
embedding = layers.Lambda(UniversalEmbedding, output_shape=(embed_size,))(input_text)
dense = layers.Dense(256, activation='relu')(embedding)
pred = layers.Dense(category_counts, activation='softmax')(dense)
model = Model(inputs=[input_text], outputs=pred)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()
with tf.Session() as session:
   K.set_session(session)
   session.run(tf.global_variables_initializer())
   session.run(tf.tables_initializer())
   history = model.fit(train_text, 
        train_label,
        validation_data=(test_text, test_label),
        epochs=10,
        batch_size=32)
  model.save_weights('./model.h5')

我保存了模型权重,希望以后再使用。我可以通过以下方式加载权重进行建模 model.load_weights(self.model_weight_file) 但是,每次我需要做出预测时,我都需要创建一个会话

with tf.Session() as session:
    K.set_session(session)
    session.run(tf.global_variables_initializer())
    session.run(tf.tables_initializer())
    predicts = self.model.predict(text)

会话初始化需要很长时间,如果我需要在每个预测中创建一个会话,这会给我带来巨大的开销。有没有一种方法可以一次创建会话并在每次预测中坚持重用它?

0 个答案:

没有答案