如何从TensorFlow中的嵌入矩阵中获取随机嵌入?

时间:2018-05-07 05:20:58

标签: tensorflow word-embedding

假设我为训练集中的某些单词创建了一个嵌入矩阵,如下所示:

with train_graph.as_default():
    word_embedding = tf.Variable(tf.random_uniform((n_vocab, 256), 0, 1), name='word_embedding')

我可以使用tf.nn.embedding_lookup(word_embedding, inputs)在培训期间获取特定批次中的字词的嵌入。但是,如何进行随机嵌入并将其用于matmul操作?

1 个答案:

答案 0 :(得分:0)

以下是使用热切执行的示例:

# Create the variable
x = tfe.Variable(tf.ones([10, 20]))
# Get random integer in [0, 20)
y = tf.random_uniform(shape=(), maxval=20, dtype=tf.int32)
# Get the y'th row from x.
z = tf.gather(x, y, axis=1)

对于常规TensorFlow,只需替换变量创建。