关于word2vec`generate_batch()`如何工作的澄清?

时间:2018-06-13 14:54:25

标签: python tensorflow

我一直在努力了解如何将它应用到我的测试和数据集中(我发现github上的张量流代码太复杂而且不是很简单)。

我将使用skip-gram模型。这是我写的代码。我想要一个非神秘的解释,说明发生了什么以及我需要做些什么来使这项工作。

def generate_batch(self):
    inputs = []
    labels = []
    for i,phrase in enumerate(self.training_phrases): # training_phrases look like this: ['I like that cat', '...', ..]
        array_list = utils.skip_gram_tokenize(phrase) # This transforms a sentence into an array of arrays of numbers representing the sentence, ex. [[181, 152], [152, 165], [165, 208], [208, 41]]
        for array in array_list:
            inputs.append(array) # I noticed that this is useless, I could just do inputs = array_list

    return inputs, labels

这就是我现在所处的位置。从tensorflow在github上提供的generate_batch(),我可以看到它返回inputs, labels

我假设输入是跳过克的数组,但什么是标签?我该如何生成它们?

另外,我看到它实现了batch_size,我怎么能这样做(我假设我必须将数据分成更小的部分,但是它是如何工作的?我将数据放入数组中?)。

关于batch_size,如果批量大小为16,但数据仅提供130个输入会发生什么?我是否定期进行8批次,然后是2个输入的小批量?

1 个答案:

答案 0 :(得分:1)

对于skip-gram,您需要将input-label对作为current word及其context word。每个context word的{​​{1}}在文本短语的窗口中定义。

考虑以下文字短语:input word对于"Here's looking at you kid".的窗口,对于当前单词3,您有两个上下文单词at和{{1} }。因此looking对是you,您可以将它们转换为数字表示。

在上面的代码中,数组列表示例如下:input label,表示当前单词及其上下文是为下一个单词而不是前一个单词定义的。

架构如下所示:

enter image description here

现在你已经生成了这些对,分批获取并训练它们。批量不均匀,但确保您的损失为{at, looking}, {at, you}而不是ex. [[181, 152], [152, 165], [165, 208], [208, 41]]