Tensorflow:RNN示例,单词嵌入来自何处

时间:2017-05-12 17:02:57

标签: python tensorflow

这里有一个Tensorflow LSTM示例:

https://github.com/tensorflow/models/blob/master/tutorials/rnn/ptb/ptb_word_lm.py

我试图了解嵌入这个词的来源。

看,这是代码:

with tf.device("/cpu:0"):
    embedding = tf.get_variable(
        "embedding", [vocab_size, size], dtype=data_type())
    inputs = tf.nn.embedding_lookup(embedding, input_.input_data)

我对embedding变量应该保留的内容有所了解(正如在本例中所解释的那样:https://www.tensorflow.org/tutorials/word2vec。但是必须有一些魔力来完成工作(训练嵌入模型等。)

我在项目代码中没有看到类似内容。我找不到任何可以生成简单的单热编码向量的东西。它只是用整数ID替换单词,然后在读取器代码中重新整形数据(https://github.com/tensorflow/models/blob/master/tutorials/rnn/ptb/reader.py)。

我错过了什么?如果这显然是显而易见的话,我真的很抱歉。

1 个答案:

答案 0 :(得分:1)

我不完全确定,但这是我所理解的:

我认为嵌入一个可训练的张量。 tf.get_variable()获取带有这些参数的现有变量,如果没有,则创建一个新变量。

如果初始化程序为None,则将使用glorot_uniform_initializer。

根据词汇大小,我们初始化大嵌入矩阵,让它为我们的词汇找到最佳嵌入值。