从过去的几天开始,我一直试图在代码https://github.com/tensorflow/models/blob/master/tutorials/embedding/word2vec.py#L28中找出执行流程。
我理解负抽样和损失函数背后的逻辑,但我对列车函数内的执行流程感到困惑,特别是涉及_train_thread_body
函数时。我对while和if循环(影响是什么)以及并发相关部分感到困惑。如果有人可以给出一个不错的解释,那么在向下投票之前会很棒。
答案 0 :(得分:1)
此sample code被称为" 多线程 word2vec迷你批量跳过 - 克模型",这就是为什么它使用多个独立线程进行培训。 Word2Vec也可以使用单个线程进行训练,但本教程表明word2vec在并行完成时计算速度更快。
输入,标签和纪元张量由本地word2vec.skipgram_word2vec
函数提供,该函数在tutorials/embedding/word2vec_kernels.cc
文件中实现。在那里你可以看到current_epoch
是一个张量,一旦整个句子语句被处理就会更新。
您提出的方法实际上非常简单:
def _train_thread_body(self):
initial_epoch, = self._session.run([self._epoch])
while True:
_, epoch = self._session.run([self._train, self._epoch])
if epoch != initial_epoch:
break
首先,它计算当前时期,然后调用训练直到epoch
增加。这意味着运行此方法的所有线程都将使完全一个训练时代。每个线程与其他线程并行执行一步。
self._train
是一个优化损失函数的操作(请参阅optimize
方法),该函数根据当前examples
和labels
计算(请参阅build_graph
方法)。这些张量的确切值再次使用本机代码,即NextExample
。实质上,word2vec.skipgram_word2vec
的每次调用都会提取一组示例和标签,这些示例和标签构成了优化函数的输入。希望,现在让它更清晰。
顺便说一下,这个模型在训练中使用NCE loss,而不是负抽样。