Tensorflow tf.nn.embedding_lookup

时间:2017-11-25 11:33:32

标签: tensorflow word-embedding embedding-lookup

tf.nn.embedding_lookup中是否有一个小型神经网络? 当我训练一些数据时,相同索引的值正在改变。 它也受过训练吗?我正在训练我的模特

我检查了官方的embedding_lookup代码但是我看不到任何用于列车嵌入参数的tf.Variables。 但是当我打印所有tf.Variables时,我可以找到一个嵌入范围内的变量

谢谢。

1 个答案:

答案 0 :(得分:2)

是的,学习了嵌入。您可以将tf.nn.embedding_lookup操作视为更有效地执行以下矩阵乘法:

import tensorflow as tf
import numpy as np

NUM_CATEGORIES, EMBEDDING_SIZE = 5, 3
y = tf.placeholder(name='class_idx', shape=(1,), dtype=tf.int32)

RS = np.random.RandomState(42)
W_em_init = RS.randn(NUM_CATEGORIES, EMBEDDING_SIZE)
W_em = tf.get_variable(name='W_em',
                       initializer=tf.constant_initializer(W_em_init),
                       shape=(NUM_CATEGORIES, EMBEDDING_SIZE))

# Using tf.nn.embedding_lookup
y_em_1 = tf.nn.embedding_lookup(W_em, y)

# Using multiplication
y_one_hot = tf.one_hot(y, depth=NUM_CATEGORIES)
y_em_2 = tf.matmul(y_one_hot, W_em)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
sess.run([y_em_1, y_em_2], feed_dict={y: [1.0]})
# [array([[ 1.5230298 , -0.23415338, -0.23413695]], dtype=float32),
#  array([[ 1.5230298 , -0.23415338, -0.23413695]], dtype=float32)]

变量W_em将以完全相同的方式进行培训,无论您使用的是y_em_1还是y_em_2;不过,y_em_1可能更有效率。