加载预先训练的word2vec以在Estimator model_fn中初始化embedding_lookup

时间:2017-06-21 15:46:10

标签: tensorflow word2vec google-cloud-ml-engine

我正在解决文本分类问题。我使用Estimator类和我自己的model_fn来定义我的分类器。我想使用Google预先训练好的word2vec嵌入作为初始值,然后针对手头的任务进一步优化它。

我看到这篇文章:Using a pre-trained word embedding (word2vec or Glove) in TensorFlow
这解释了如何在'原始'TensorFlow代码中进行处理。但是,我真的想使用Estimator类。

作为扩展,我想在Cloud ML Engine上训练此代码,是否有一种传递具有初始值的相当大的文件的好方法?

假设我们有类似的东西:

def build_model_fn():
    def _model_fn(features, labels, mode, params):
        input_layer = features['feat'] #shape=[-1, params["sequence_length"]]
        #... what goes here to initialize W

        embedded = tf.nn.embedding_lookup(W, input_layer)
        ...
        return predictions

estimator = tf.contrib.learn.Estimator(
    model_fn=build_model_fn(),
    model_dir=MODEL_DIR,
    params=params)
estimator.fit(input_fn=read_data, max_steps=2500)

1 个答案:

答案 0 :(得分:8)

嵌入通常足够大,唯一可行的方法是使用它们初始化图表中的tf.Variable。这将允许您利用分布式等中的param服务器。

对于这个(以及其他任何内容),我建议你使用新的“核心”估算器tf.estimator.Estimator,因为这会让事情变得更容易。

根据您提供的链接中的答案,并且知道我们希望变量不是常数,我们可以采取方法:

(2)使用feed dict初始化变量,或  (3)从检查点加载变量

我将首先介绍选项(3),因为它更容易,更好:

model_fn中,只需使用tf.contrib.framework.load_variable调用返回的Tensor初始化变量即可。这需要:

  1. 你有一个有效的TF检查点与你的嵌入
  2. 您知道检查点中嵌入变量的完全限定名称。
  3. 代码非常简单:

    def model_fn(mode, features, labels, hparams):
      embeddings = tf.Variable(tf.contrib.framework.load_variable(
          'gs://my-bucket/word2vec_checkpoints/',
          'a/fully/qualified/scope/embeddings'
      ))
      ....
      return tf.estimator.EstimatorSpec(...)
    

    但是,如果您的嵌入不是由另一个TF模型生成的,那么这种方法对您不起作用,因此选项(2)。

    对于(2),我们需要使用tf.train.Scaffold,它本质上是一个配置对象,它包含了启动tf.Session的所有选项(由于很多原因故意隐藏了估算器)。

    您可以在Scaffold中返回的tf.train.EstimatorSpec中指定model_fn

    我们在model_fn中创建一个占位符,并将其设为 我们的嵌入变量的初始化程序操作,然后通过init_feed_dict传递Scaffold。 e.g。

    def model_fn(mode, features, labels, hparams):
      embed_ph = tf.placeholder(
          shape=[hparams.vocab_size, hparams.embedding_size], 
          dtype=tf.float32)
      embeddings = tf.Variable(embed_ph)
      # Define your model
      return tf.estimator.EstimatorSpec(
          ..., # normal EstimatorSpec args
          scaffold=tf.train.Scaffold(init_feed_dict={embed_ph: my_embedding_numpy_array})
      )
    

    这里发生的是init_feed_dict将在运行时填充embed_ph占位符的值,然后允许embeddings.initialization_op(占位符的赋值)运行。