Getting a sentence embedding by taking the mean of all word embeddings in TensorFlow?

Time: 2018-12-31 19:58:35

Tags: python tensorflow

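Here is my code, which splits an input Tensor of type tf.string and looks up the embedding of each word with a pretrained GloVe model. However, I get errors related to the cond implementation. Is there a cleaner way to get the embeddings of all the words in a string tensor?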

# Take out the words
target_words = tf.string_split([target_sentence], delimiter=" ")

# Tensorflow parallel while loop variable, condition and body
i = tf.constant(0, dtype=tf.int32)
cond = lambda self, i: tf.less(x=tf.cast(i, tf.int32), y=tf.cast(tf.shape(target_words)[0], tf.int32))
sentence_mean_embedding = tf.Variable([], trainable=False)

def body(i, sentence_mean_embedding):
    sentence_mean_embedding = tf.concat(1, tf.nn.embedding_lookup(params=tf_embedding, ids=tf.gather(target_words, i)))

    return sentence_mean_embedding

embedding_sentence = tf.reduce_mean(tf.while_loop(cond, body, [i, sentence_mean_embedding]))

1 Answer:

Answer 0 (score: 2)

A cleaner approach is to use index_table_from_file together with the Dataset API.

First, create your own tf.data.Dataset (I assume we have two sentences with arbitrary labels):

sentence = tf.constant(['this is first sentence', 'this is second sentence'])
labels = tf.constant([1, 0])
dataset = tf.data.Dataset.from_tensor_slices((sentence, labels))

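Second, create a vocab.txt file in which the line number of each word maps to the same index in the GloVe embedding. For example, if the first word in the GloVe vocabulary is "absent", then the first line of vocab.txt should be "absent", and so on. For simplicity, assume our vocab.txt contains the following words: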

first
is
test
this
second
sentence

Then define a lookup table whose goal is to convert each word into a specific id:

table = tf.contrib.lookup.index_table_from_file(vocabulary_file="vocab.txt", num_oov_buckets=1)
dataset = dataset.map(lambda x, y: (tf.string_split([x]).values, y))
dataset = dataset.map(lambda x, y: (tf.cast(table.lookup(x), tf.int32), y))

dataset = dataset.batch(1)
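
To make the table's behaviour concrete, here is a minimal plain-Python sketch of the word-to-id mapping the lookup is expected to produce for the example vocab.txt. The helpers `build_vocab_index` and `lookup_ids` are hypothetical and only mimic the idea. The sketch assumes that ids are zero-based line numbers and that, with a single OOV bucket, every unknown word receives the id equal to the vocabulary size.

```python
# Plain-Python illustration of a vocabulary lookup with one out-of-vocabulary (OOV) bucket.
# build_vocab_index and lookup_ids are hypothetical helpers, not TensorFlow APIs.

VOCAB_LINES = ["first", "is", "test", "this", "second", "sentence"]  # contents of vocab.txt


def build_vocab_index(lines):
    """Map each word to its zero-based line number."""
    return {word: idx for idx, word in enumerate(lines)}


def lookup_ids(sentence, vocab_index):
    """Split on whitespace and map each word to an id; unknown words share one OOV id."""
    # Assumption: a single OOV bucket placed right after the last vocabulary entry.
    oov_id = len(vocab_index)
    return [vocab_index.get(word, oov_id) for word in sentence.split()]


vocab_index = build_vocab_index(VOCAB_LINES)
ids_first = lookup_ids("this is first sentence", vocab_index)
ids_second = lookup_ids("this is second sentence", vocab_index)
ids_unknown = lookup_ids("this is unknown", vocab_index)
print(ids_first, ids_second, ids_unknown)
```

Since these ids are later used as row indices, the embedding matrix would need at least one more row than vocab.txt has lines so that the OOV id also points at a valid row.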

Finally, following the approach from a related answer, convert each sentence into an embedding with nn.embedding_lookup():

glove_weights = tf.get_variable('embed', shape=embedding.shape, initializer=tf.constant_initializer(embedding), trainable=False)

iterator = dataset.make_initializable_iterator()
x, y = iterator.get_next()

embedding = tf.nn.embedding_lookup(glove_weights, x)
sentence = tf.reduce_mean(embedding, axis=1)
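
The last two lines amount to a row lookup followed by an average over the word axis. Below is a minimal plain-Python sketch of that arithmetic. It assumes a made-up 6 x 2 embedding table (real GloVe vectors are far larger), and `mean_embedding` is a hypothetical helper, not part of TensorFlow.

```python
# Toy embedding table: row i is the vector for word id i (values invented for illustration).
EMBEDDINGS = [
    [0.0, 0.0],
    [1.0, 2.0],
    [2.0, 4.0],
    [3.0, 6.0],
    [4.0, 8.0],
    [5.0, 10.0],
]


def mean_embedding(word_ids, table):
    """Look up one row per word id and average the rows component-wise."""
    rows = [table[i] for i in word_ids]
    return [sum(column) / len(rows) for column in zip(*rows)]


# Ids for "this is first sentence", assuming zero-based line numbers in the example vocab.txt.
sentence_vector = mean_embedding([3, 1, 0, 5], EMBEDDINGS)
print(sentence_vector)
```

The result has one value per embedding dimension, regardless of how many words the sentence contains.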

Complete code in eager mode:

import tensorflow as tf

tf.enable_eager_execution()

sentence = tf.constant(['this is first sentence', 'this is second sentence'])
labels = tf.constant([1, 0])

dataset = tf.data.Dataset.from_tensor_slices((sentence, labels))
table = tf.contrib.lookup.index_table_from_file(vocabulary_file="vocab.txt", num_oov_buckets=1)
dataset = dataset.map(lambda x, y: (tf.string_split([x]).values, y))
dataset = dataset.map(lambda x, y: (tf.cast(table.lookup(x), tf.int32), y))

dataset = dataset.batch(1)

# Randomly initialized stand-in for the pretrained GloVe matrix (10000 words x 300 dimensions)
glove_weights = tf.get_variable('embed', shape=(10000, 300), initializer=tf.truncated_normal_initializer())

for x, y in dataset:
    embedding = tf.nn.embedding_lookup(glove_weights, x)
    sentence = tf.reduce_mean(embedding, axis=1)
    print(sentence.shape)
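
The code above uses batch(1), which sidesteps the fact that sentences can contain different numbers of words. If sentences were instead padded to a common length so that several fit into one batch, a plain mean over the word axis would also average in the padding positions. One possible workaround, sketched here in plain Python, is to average only the real words. The helper `masked_mean`, the padding id of -1, and the toy table are all assumptions made for illustration.

```python
# Sketch of a padding-aware mean; none of these names come from TensorFlow.
PAD_ID = -1  # assumption: an id that never occurs in the vocabulary

# Toy embedding table: row i is the vector for word id i (values invented for illustration).
TOY_TABLE = [
    [0.0, 0.0],
    [1.0, 2.0],
    [2.0, 4.0],
    [3.0, 6.0],
    [4.0, 8.0],
    [5.0, 10.0],
]


def masked_mean(padded_ids, table, pad_id=PAD_ID):
    """Average the embedding rows of all non-padding ids in one sentence."""
    rows = [table[i] for i in padded_ids if i != pad_id]
    if not rows:
        # A sentence made only of padding gets a zero vector.
        return [0.0] * len(table[0])
    return [sum(column) / len(rows) for column in zip(*rows)]


# Two sentences padded to length 4; the second one has only two real words.
batch = [
    [3, 1, 0, 5],
    [3, 1, PAD_ID, PAD_ID],
]
sentence_vectors = [masked_mean(ids, TOY_TABLE) for ids in batch]
print(sentence_vectors)
```

Dividing by the number of real words rather than the padded length keeps short sentences from being pulled toward the padding vector.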