TensorFlow: global_step value does not change (always 0)

Date: 2018-05-01 23:22:09

Tags: tensorflow

I create the global_step variable with get_or_create_global_step and then read it back with get_global_step, but the value I get never changes during training. Here is the relevant code:

with tf.device(self.device):
    ...
    loss = model.loss(....)
    global_step = tf.train.get_or_create_global_step()
    update = model.update(loss, global_step, self.lrate, self.grad_clip)
    init_op = tf.global_variables_initializer()
    ...

with tf.train.MonitoredTrainingSession(master=self.server.target,
                                       is_chief=(self.task_index == 0),
                                       config=config,
                                       hooks=hooks) as mon_sess:
    mon_sess.run(init_op)
    _, lossVal, step = mon_sess.run([update, loss, tf.train.get_global_step()])
    print('step %d, train loss = %f' % (step, lossVal))
    ...

Here model is an instance of the Model class, whose update method is implemented as follows:

class Model(object):
    def __init__(self):
        ...

    def update(self, loss, global_step, learning_rate, grad_clip):
        optimizer = tf.train.AdamOptimizer(learning_rate)
        grads_and_vars = optimizer.compute_gradients(loss=loss)
        ...
        update_op = optimizer.apply_gradients(grads_and_vars=clipped_grads_and_vars,
                                              global_step=global_step,
                                              name='apply_gradients')
        return update_op
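
For reference, optimizer.apply_gradients increments the variable passed as global_step by one every time the returned op is run, so the step should advance whenever update is fetched. Below is a minimal sketch of an update written in that style; the tf.clip_by_global_norm call and the build_update name are assumptions standing in for the clipping code elided above, not the actual implementation:

import tensorflow as tf

def build_update(loss, global_step, learning_rate, grad_clip):
    # Hypothetical sketch: the clipping step is assumed, since it is elided above.
    optimizer = tf.train.AdamOptimizer(learning_rate)
    grads_and_vars = optimizer.compute_gradients(loss=loss)
    grads, variables = zip(*grads_and_vars)
    clipped_grads, _ = tf.clip_by_global_norm(grads, grad_clip)
    clipped_grads_and_vars = list(zip(clipped_grads, variables))
    # apply_gradients adds 1 to global_step each time the returned op runs.
    return optimizer.apply_gradients(grads_and_vars=clipped_grads_and_vars,
                                     global_step=global_step,
                                     name='apply_gradients')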

During training, this prints:

step 0, train loss = ...
step 0, train loss = ...
step 0, train loss = ...
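
For comparison, here is a standalone toy sketch (assuming a single local process, so no master, is_chief, or hooks) in which the step does advance; it is a minimal illustration of the API, not a drop-in fix for the distributed setup above:

import tensorflow as tf

# Toy quadratic loss with a single variable.
w = tf.Variable(5.0, name='w')
loss = tf.square(w)

global_step = tf.train.get_or_create_global_step()
optimizer = tf.train.AdamOptimizer(0.1)
train_op = optimizer.minimize(loss, global_step=global_step)

# MonitoredTrainingSession runs the variable initializers itself,
# so no explicit init_op is run inside the session.
with tf.train.MonitoredTrainingSession() as mon_sess:
    for _ in range(3):
        _, loss_val, step = mon_sess.run([train_op, loss, global_step])
        # Note: when fetched in the same run() call as the train op, the step
        # value can be off by one, but it should still increase across iterations.
        print('step %d, train loss = %f' % (step, loss_val))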

Can anyone help?

Thanks!

0 Answers:

No answers yet.