I create the global_step variable with get_or_create_global_step and then read it with get_global_step, but the value I see during training never changes. Here is the relevant code:
with tf.device(self.device):
    ...
    loss = model.loss(....)
    global_step = tf.train.get_or_create_global_step()
    update = model.update(loss, global_step, self.lrate, self.grad_clip)
    init_op = tf.global_variables_initializer()
    ...

with tf.train.MonitoredTrainingSession(master=self.server.target, is_chief=(self.task_index == 0), config=config, hooks=hooks) as mon_sess:
    mon_sess.run(init_op)
    _, lossVal, step = mon_sess.run([update, loss, tf.train.get_global_step()])
    print('step %d, train loss = %f' % (step, lossVal))
    ...
Here model is an instance of the Model class, whose update method is implemented as follows:
class Model(object):
    def __init__(self):
        ...

    def update(self, loss, global_step, learning_rate, grad_clip):
        optimizer = tf.train.AdamOptimizer(learning_rate)
        grads_and_vars = optimizer.compute_gradients(loss=loss)
        ...
        update_op = optimizer.apply_gradients(grads_and_vars=clipped_grads_and_vars, global_step=global_step, name='apply_gradients')
        return update_op
The above code prints:
step 0, train loss = ...
step 0, train loss = ...
step 0, train loss = ...
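
For comparison, below is a minimal, self-contained sketch of the pattern I believe I am following (toy variable, toy loss, and a made-up learning rate, nothing from my real model). As I understand it, global_step should go up by one on every apply_gradients call in a loop like this:

import tensorflow as tf

# Toy variable and loss, only to exercise the global_step machinery.
w = tf.get_variable('w', shape=[], initializer=tf.zeros_initializer())
loss = tf.square(w - 1.0)

global_step = tf.train.get_or_create_global_step()
optimizer = tf.train.AdamOptimizer(0.1)  # made-up learning rate for this sketch
grads_and_vars = optimizer.compute_gradients(loss)
# Passing global_step to apply_gradients is what should increment it each step.
update = optimizer.apply_gradients(grads_and_vars, global_step=global_step)

# No distributed arguments here; MonitoredTrainingSession handles initialization itself.
with tf.train.MonitoredTrainingSession() as sess:
    for _ in range(3):
        _, step = sess.run([update, tf.train.get_global_step()])
        print('step %d' % step)

The only obvious differences from my real code are the distributed MonitoredTrainingSession arguments and the explicit mon_sess.run(init_op) call.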
Can anyone help?
Thanks!