我有140K句子要嵌入。我正在使用TF_HUB通用语句编码器并正在对句子进行迭代(我知道这不是最好的方法,但是当我尝试向模型中输入500个以上的句子时,它会崩溃)。 我的环境是: Ubuntu 18.04 Python 3.7.4 TF 1.14 内存:16GB 处理器:i-5
我的代码是:
版本1 我在tf.session上下文管理器中进行了迭代
embed = hub.Module("https://tfhub.dev/google/universal-sentence-encoder-large/3")
df = pandas_repository.get_dataframe_from_table('sentences')
with tf.compat.v1.Session() as session:
session.run(tf.global_variables_initializer())
session.run(tf.tables_initializer())
sentence_embedding = None
for i, row in df.iterrows():
sentence = row['content']
embeddings = embed([sentence])
sentence_embedding = session.run(embeddings)
df.at[i, 'embedding'] = sentence_embedding
print('processed index:', i)
版本2 我每次迭代都打开和关闭一个会话
embed = hub.Module("https://tfhub.dev/google/universal-sentence-encoder-large/3")
df = pandas_repository.get_dataframe_from_table('sentences')
for i, row in df.iterrows():
sentence = row['content']
embeddings = embed([sentence])
sentence_embedding = None
with tf.compat.v1.Session() as session:
session.run(tf.global_variables_initializer())
session.run(tf.tables_initializer())
sentence_embedding = session.run(embeddings)
df.at[i, 'embedding'] = sentence_embedding
print('processed index:', i)
虽然版本2 确实具有某种GC,但内存已被清除一些。它仍然超过50件并爆炸。
版本1 仅用于记忆。
提供的正确解决方案def calculate_embeddings(dataframe, table_name):
sql_get_sentences = "SELECT * FROM semantic_similarity.sentences WHERE embedding IS NULL LIMIT 1500"
sql_update = 'UPDATE {} SET embedding = data.embedding FROM (VALUES %s) AS data(id, embedding) WHERE {}.id = data.id'.format(table_name, table_name)
df = pandas_repository.get_dataframe_from_sql(sql_get_sentences)
with hub.eval_function_for_module("https://tfhub.dev/google/universal-sentence-encoder-large/3") as embed:
while len(df) >= 0:
sentence_array = df['content'].values
sentence_embeddings = embed(sentence_array)
df['embedding'] = sentence_embeddings.tolist()
values = [tuple(x) for x in df[['id', 'embedding']].values]
pandas_repository.update_db_from_df('semantic_similarity.sentences', sql_update, values)
df = pandas_repository.get_dataframe_from_sql(sql_get_sentences)
我是TF的新手,可以使用我所能获得的任何帮助。
答案 0 :(得分:1)
您的代码使用tf.Session,因此它属于TF1.x编程模型,该模型首先构建一个数据流图,然后在馈入输入并从该图获取输出的情况下重复运行它。
但是您的代码与该编程模型不太吻合。这两个版本都继续向Hubs.Module添加新应用程序(调用)到默认的TensorFlow图,而不是一次应用并为各种输入重复运行同一图。版本2不断地进入和退出tf.Sessions,这释放了一些内存,但是效率很低。
请参阅我对“ Strongly increasing memory consumption when using ELMo from Tensorflow-Hub”的回答,以获取有关如何在TensorFlow 1.x的基于图的编程模型中正确进行操作的指导。
TensorFlow 2.0(即将发布)默认为“渴望执行”的编程模型,该模型消除了图形和会话,并避免了这种混乱。 TensorFlow Hub将适时更新为TF2.0。有关您的用例的预览,请参见https://colab.research.google.com/github/tensorflow/hub/blob/master/examples/colab/tf2_text_classification.ipynb