word2vec tensorflow中的执行流程

时间:2017-11-13 05:26:48

标签: python multithreading tensorflow parallel-processing word2vec

从过去的几天开始,我一直试图在代码https://github.com/tensorflow/models/blob/master/tutorials/embedding/word2vec.py#L28中找出执行流程。

我理解负抽样和损失函数背后的逻辑,但我对列车函数内的执行流程感到困惑,特别是涉及_train_thread_body函数时。我对while和if循环(影响是什么)以及并发相关部分感到困惑。如果有人可以给出一个不错的解释,那么在向下投票之前会很棒。

1 个答案:

答案 0 :(得分:1)

sample code被称为" 多线程 word2vec迷你批量跳过 - 克模型",这就是为什么它使用多个独立线程进行培训。 Word2Vec也可以使用单个线程进行训练,但本教程表明word2vec在并行完成时计算速度更快。

输入,标签和纪元张量由本地word2vec.skipgram_word2vec函数提供,该函数在tutorials/embedding/word2vec_kernels.cc文件中实现。在那里你可以看到current_epoch是一个张量,一旦整个句子语句被处理就会更新。

您提出的方法实际上非常简单:

def _train_thread_body(self):
  initial_epoch, = self._session.run([self._epoch])
  while True:
    _, epoch = self._session.run([self._train, self._epoch])
    if epoch != initial_epoch:
      break

首先,它计算当前时期,然后调用训练直到epoch增加。这意味着运行此方法的所有线程都将使完全一个训练时代。每个线程与其他线程并行执行一步。

self._train是一个优化损失函数的操作(请参阅optimize方法),该函数根据当前exampleslabels计算(请参阅build_graph方法)。这些张量的确切值再次使用本机代码,即NextExample。实质上,word2vec.skipgram_word2vec的每次调用都会提取一组示例和标签,这些示例和标签构成了优化函数的输入。希望,现在让它更清晰。

顺便说一下,这个模型在训练中使用NCE loss,而不是负抽样。