“tf.nn.embedding_lookup”的工作是什么

时间:2018-05-19 08:00:48

标签: tensorflow text-classification word-embedding embedding-lookup

我正在尝试使用CNN实现嵌入层进行文本分类 嵌入层
with tf.device('/cpu:0'), tf.name_scope("embedding"): self.W = tf.Variable(tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0),name="W") self.embedded_chars = tf.nn.embedding_lookup(self.W, self.inputTensor) self.embedded_chars_expanded = tf.expand_dims(self.embedded_chars, -1)

我无法理解tf.nn.embedding_lookup正在工作。

1 个答案:

答案 0 :(得分:0)

  

此函数用于在参数的张量列表中执行并行查找。

它是tf.gather的一般形式。此示例将清除tf.gathertf.nn.embedding_lookup

的工作情况

假设你有一个形状的张量(1),包含字符串。我们称之为params。

<强> PARAMS

|。 0. |。 1. |。 2. |。 3. |。 4. |。 5. | &lt; = index

|。 A1。 |。 a2。 |。 A3。 | A4。 |。 A5。 | a6 | &lt; = values

设Id是int32或int64的另一个张量

<强> IDS

[2,3]

然后这个函数将params中这些索引处的值作为另一个张量返回。

在上述情况下,它返回。 [a3,a4]

This image should make it clear

所以在上面的例子中,self.W在self.InputTensor指向的索引处的值由tf.nn.embedding_lookup函数提取。