使用Tensorflow-Hub的ELMo时会大大增加内存消耗

时间:2019-06-07 06:14:19

标签: python tensorflow tensorflow-hub elmo

我目前正在尝试比较数百万个文档的相似性。对于在CPU上的第一个测试,我将它们每个减少到大约50个字符,并尝试一次获得其中10个的ELMo嵌入:

ELMO = "https://tfhub.dev/google/elmo/2"
for row in file:
    split = row.split(";", 1)
    if len(split) > 1:
        text = split[1].replace("\n", "")
            texts.append(text[:50])
    if i == 300:
        break
    if i % 10 == 0:
        elmo = hub.Module(ELMO, trainable=False)
                 executable = elmo(
                 texts,
                 signature="default",
                 as_dict=True)["elmo"]

    vectors = execute(executable)
    texts = []
    i += 1

但是,即使是这个小例子,在大约300个句子之后(甚至不保存向量),程序仍要消耗多达12GB的RAM。这是一个已知问题(我发现的其他问题表明存在类似的问题,但还不是那么极端)还是我犯了一个错误?

1 个答案:

答案 0 :(得分:1)

我想这是针对没有Eager模式的TensorFlow 1.x(否则使用hub.Module可能会遇到更大的问题)。

在该编程模型中,您需要首先在TensorFlow图中表达您的计算,然后对每批数据重复执行该图。

  • 使用hub.Module()构造模块并将其应用于将输入张量映射到输出张量都是图构建的一部分,并且应该只发生一次。

  • 输入数据上的循环应仅调用session.run()来馈送输入并从固定图中获取输出数据。

幸运的是,已经有一个实用程序函数可以为您完成所有这些操作:

import numpy as np
import tensorflow_hub as hub

# For demo use only. Extend to your actual I/O needs as you see fit.
inputs = (x for x in ["hello world", "quick brown fox"])

with hub.eval_function_for_module("https://tfhub.dev/google/elmo/2") as f:
  for pystr in inputs:
    batch_in = np.array([pystr])
    batch_out = f(batch_in)
    print(pystr, "--->", batch_out[0])

就原始TensorFlow而言,这对您的作用大致是这样:

module = Module(ELMO_OR_WHATEVER)
tensor_in = tf.placeholder(tf.string, shape=[None])  # As befits `module`.
tensor_out = module(tensor_in)

# This kind of session handles init ops for you.
with tf.train.SingularMonitoredSession() as sess:
  for pystr in inputs:
    batch_in = np.array([pystr])
    batch_out = sess.run(tensor_out, feed_dict={tensor_in: batch_in}
    print(pystr, "--->", batch_out[0])

如果您的需求对于with hub.eval_function_for_module ...而言过于复杂,则可以构建更明确的示例。

注意在循环中既不构造也不调用hub.Module。

PS:厌倦了担心构建图形还是运行会话的烦恼?然后TF2和热切的执行适合您。检出https://colab.research.google.com/github/tensorflow/hub/blob/master/examples/colab/tf2_text_classification.ipynb