Slightly edited official TensorFlow example does not work correctly

Date: 2018-09-26 12:34:11

Tags: python tensorflow graph

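I am working on word2vec with distributed TensorFlow. For compatibility reasons, I only slightly edited the official word2vec example so that it follows a Model-style class structure.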

The code snippet is as follows:

def build(self):
    self.global_step = tf.train.get_or_create_global_step()
    with tf.variable_scope("weights", partitioner=partitioner):

        self.embeddings = tf.get_variable(name="embeddings", shape=(self.vocab_size, self.embedding_size), initializer=tf.random_uniform_initializer(minval=-1.0, maxval=1.0))
        self.nce_weights = tf.get_variable(name="nce_weights", shape=(self.vocab_size, self.embedding_size), initializer=tf.truncated_normal_initializer(stddev=1.0/math.sqrt(self.embedding_size)))
        self.bias = tf.get_variable(name="bias", shape=(self.vocab_size), initializer=tf.zeros_initializer())

    self.embeded = tf.nn.embedding_lookup(self.embeddings, inputs, partition_strategy='div')
    print("labels: ", self.labels)

    self.loss = tf.reduce_mean(
        tf.nn.nce_loss(
            weights = self.nce_weights,
            biases = self.bias,
            labels = self.labels,
            inputs = self.embeded,
            num_sampled = self.num_sampled,
            num_classes = self.vocab_size,
            partition_strategy="div"
        )
    )

    self.optimizer = tf.train.GradientDescentOptimizer(self.learning_rate).minimize(self.loss, global_step=self.global_step)

    # evaluate
    normized = tf.sqrt(tf.reduce_sum(tf.square(self.embeddings), 1, keepdims=True))
    normallized_embeddings = self.embeddings / normized
    valid_data = np.r_[1:5]
    self.valid_size = len(valid_data)
    evaluate_examples = tf.constant(valid_data)
    valid_embeddings = tf.nn.embedding_lookup(normallized_embeddings, evaluate_examples)
    self.similarity = tf.matmul(valid_embeddings, normallized_embeddings,  transpose_b=True)
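
For reference, the evaluation step at the end of build() can be mirrored outside the graph. The following is a minimal pure-Python sketch, not the author's code: the helper names (normalize_rows, cosine_similarity, nearest) and the toy embedding table are hypothetical, and every row is assumed to have a non-zero norm. It normalizes each embedding row to unit length and takes dot products, so each value is a cosine similarity that depends only on the embeddings matrix.

```python
import math


def normalize_rows(matrix):
    """Scale every row to unit L2 length (mirrors embeddings / norm)."""
    normalized = []
    for row in matrix:
        norm = math.sqrt(sum(x * x for x in row))  # assumes norm != 0
        normalized.append([x / norm for x in row])
    return normalized


def cosine_similarity(embeddings, valid_ids):
    """Return one row per id in valid_ids: its cosine similarity to every word."""
    unit = normalize_rows(embeddings)
    return [
        [sum(a * b for a, b in zip(unit[i], other)) for other in unit]
        for i in valid_ids
    ]


def nearest(embeddings, valid_ids, top_k):
    """Map each id in valid_ids to its top_k most similar other ids."""
    sims = cosine_similarity(embeddings, valid_ids)
    result = {}
    for row_idx, word_id in enumerate(valid_ids):
        ranked = sorted(range(len(embeddings)), key=lambda j: -sims[row_idx][j])
        result[word_id] = [j for j in ranked if j != word_id][:top_k]
    return result


# Hypothetical toy table: 4 words, embedding size 2.
toy_embeddings = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
toy_valid_ids = [0, 2]
toy_sims = cosine_similarity(toy_embeddings, toy_valid_ids)
toy_nearest = nearest(toy_embeddings, toy_valid_ids, top_k=1)
```

Because the similarity is a function of the embeddings alone, it can only change between steps if the embeddings values themselves change.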

The train method:

def train(self, session):
    loss, _, global_step, embs = session.run([self.loss, self.optimizer, self.global_step, self.embeddings])
    print(embs)

Training:

def main():
    model = Word2vec(args)
    model.build() # call the method above to build the graph
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_op)  # the initializer op has to be run, not only created
        while num_step < upperboud:
            model.train(sess)

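I print the evaluation results during training and find that they never change, even though nce_weights does change, and both global_step and local_step keep increasing. I cannot figure out what is wrong. Can anyone help point it out? Thanks.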
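
One way to narrow down a symptom like this is to snapshot each variable before and after a training step and check which ones actually moved. The sketch below is only an illustration under stated assumptions: max_abs_change and report_changes are hypothetical helpers, and the snapshots are assumed to be the values fetched with session.run and converted to nested lists (for example with .tolist() on the returned arrays). The toy numbers are made up and are not results from the model above.

```python
def max_abs_change(before, after):
    """Largest absolute element-wise difference between two equally shaped 2-D lists."""
    return max(
        abs(a - b)
        for row_before, row_after in zip(before, after)
        for b, a in zip(row_before, row_after)
    )


def report_changes(snapshots_before, snapshots_after, tolerance=0.0):
    """Map each variable name to True if any element moved by more than tolerance."""
    return {
        name: max_abs_change(snapshots_before[name], snapshots_after[name]) > tolerance
        for name in snapshots_before
    }


# Hypothetical snapshots taken before and after one training step.
before_step = {
    "embeddings": [[0.1, 0.2], [0.3, 0.4]],
    "nce_weights": [[0.0, 0.0], [0.0, 0.0]],
}
after_step = {
    "embeddings": [[0.1, 0.2], [0.3, 0.4]],
    "nce_weights": [[0.0, 0.05], [0.0, 0.0]],
}
changed = report_changes(before_step, after_step)
```

If a check like this reports that embeddings stays fixed while nce_weights moves, the next thing to inspect would be whether gradients reach the embeddings variable through the lookup at all.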

0 answers:

No answers yet