我按照本文(https://colab.research.google.com/drive/1Odry08Jm0f_YALhAt4vp9qa5w8prUzDY#scrollTo=_stfC_7VFhS8)的方法使用通用语句编码器构建模型。基本上,培训过程是
module_url = "https://tfhub.dev/google/universal-sentence-encoder-large/3"
embed = hub.Module(module_url)
def UniversalEmbedding(x):
return embed(tf.squeeze(tf.cast(x, tf.string)), signature="default", as_dict=True)["default"]
input_text = layers.Input(shape=(1,), dtype=tf.string)
embedding = layers.Lambda(UniversalEmbedding, output_shape=(embed_size,))(input_text)
dense = layers.Dense(256, activation='relu')(embedding)
pred = layers.Dense(category_counts, activation='softmax')(dense)
model = Model(inputs=[input_text], outputs=pred)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()
with tf.Session() as session:
K.set_session(session)
session.run(tf.global_variables_initializer())
session.run(tf.tables_initializer())
history = model.fit(train_text,
train_label,
validation_data=(test_text, test_label),
epochs=10,
batch_size=32)
model.save_weights('./model.h5')
我保存了模型权重,希望以后再使用。我可以通过以下方式加载权重进行建模 model.load_weights(self.model_weight_file) 但是,每次我需要做出预测时,我都需要创建一个会话
with tf.Session() as session:
K.set_session(session)
session.run(tf.global_variables_initializer())
session.run(tf.tables_initializer())
predicts = self.model.predict(text)
会话初始化需要很长时间,如果我需要在每个预测中创建一个会话,这会给我带来巨大的开销。有没有一种方法可以一次创建会话并在每次预测中坚持重用它?