Tensorflow实现word2vec

时间:2016-06-29 22:32:34

标签: python tensorflow word2vec

Tensorflow教程here指的是他们在github here上可以找到的基本实现,其中Tensorflow作者使用Skipgram模型实现word2vec向量嵌入训练/评估。

我的问题是关于generate_batch()函数中(目标,上下文)对的实际生成。

On this line Tensorflow作者随机抽取来自"中心的附近目标指数"单词滑动窗口中的单词索引。

但是,他们also keep a data structure targets_to_avoid首先添加"中心"上下文单词(当然我们不想抽样)但在我们添加之后还有其他单词。

我的问题如下:

  1. 为什么要从这个单词周围的滑动窗口进行采样,为什么不只是有一个循环并使用它们而不是采样?他们会担心word2vec_basic.py(他们的"基本"实现)中的性能/内存似乎很奇怪。
  2. 无论1)的答案是什么,为什么他们采样并跟踪他们用targets_to_avoid选择的内容?如果他们想要真正随机,他们会使用替换选择,如果他们想要确保他们获得所有选项,他们应该只使用一个循环并首先获得所有选项!
  3. 内置的tf.models.embedding.gen_word2vec也是这样的吗?如果是这样我在哪里可以找到源代码? (无法在Github回购中找到.py文件)
  4. 谢谢!

2 个答案:

答案 0 :(得分:3)

我尝试了你提出的生成批次的方法 - 有一个循环并使用整个跳过窗口。结果是:

<强> 1。更快生成批次

批量大小为128,跳过窗口为5

  • 通过逐个循环数据生成批次每10,000次批次 0.73s
  • 使用教程代码生成批次,num_skips=2每10,000批次生成 3.59s

<强> 2。过度拟合的危险性更高

保持教程代码的其余部分不变,我用两种方式训练模型并记录每2000步的平均损失:

enter image description here

这种模式反复发生。它表明,每个单词使用10个样本而不是2个样本会导致过度拟合。

以下是我用于生成批次的代码。它取代了教程的generate_batch函数。

data_index = 0

def generate_batch(batch_size, skip_window):
    global data_index
    batch = np.ndarray(shape=(batch_size), dtype=np.int32)  # Row
    labels = np.ndarray(shape=(batch_size, 1), dtype=np.int32)  # Column

    # For each word in the data, add the context to the batch and the word to the labels
    batch_index = 0
    while batch_index < batch_size:
        context = data[get_context_indices(data_index, skip_window)]

        # Add the context to the remaining batch space
        remaining_space = min(batch_size - batch_index, len(context))
        batch[batch_index:batch_index + remaining_space] = context[0:remaining_space]
        labels[batch_index:batch_index + remaining_space] = data[data_index]

        # Update the data_index and the batch_index
        batch_index += remaining_space
        data_index = (data_index + 1) % len(data)

    return batch, labels

编辑: get_context_indices是一个简单的函数,它返回skip_window中data_index周围的索引切片。有关详细信息,请参阅slice() documentation

答案 1 :(得分:1)

有一个名为num_skips的参数,它表示从单个窗口生成的(输入,输出)对的数量:[skip_window target skip_window]。所以num_skips限制了我们用作输出词的上下文词的数量。这就是generate_batch函数assert num_skips <= 2*skip_window的原因。代码只是随机选取num_skip个上下文单词来构建带目标的训练对。 但我不知道num_skips如何影响表现。