TensorFlow'global_step'变量未获得指数衰减的更新

时间:2017-03-13 01:05:07

标签: python tensorflow neural-network

我有一个网络,我正在使用学习率的指数衰减。为此,我将跟踪一个'global_step'TF变量,该变量在处理的每个批次中增加1。然而,看起来在现实中,它并没有真正得到更新。这是代码。

...
global_step = tf.Variable(0, trainable=False, name='global_step')
starter_learning_rate = 0.01
learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step, 1000, 0.50)

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    optm = tf.train.AdamOptimizer(learning_rate).minimize(cost, global_step=global_step)
init = tf.global_variables_initializer()


def train(file):
    global global_step
    for batch in batches:
        global_step += 1
        ...        
    return loss

...
global_step = 0
for epoch in EPOCHS:
    for f in files:
        loss = train(f)

函数和外部的global_step正在更新。但我的学习率并没有改变。当我将汇总附加到我的TF global_step变量时,我发现它保持不变为0。

这里有什么问题?

2 个答案:

答案 0 :(得分:2)

实际上,我还没有看到你在哪里设置 learning_rate 变量,但这是如何使用它的方式:

定义全局步变量

global_step = tf.Variable(0)

使用不同的params

定义学习速率变化的方式
learning_rate = tf.train.exponential_decay(0.1, global_step, 500, 0.7, staircase=True)

将它们传递给优化器

optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)

答案 1 :(得分:1)

这里有两个问题。

  1. 您不应该自己增加变量
  2. 你实际上并没有手动递增global_step,即使它看起来像是你。
  3. 第1期

    根据tf.train.AdamOptimizer的文档,调用minimize()是调用compute_gradients()apply_gradients()的捷径。您实际返回到optm变量的内容是:

      

    应用[...]渐变的Operation。如果global_step不是无,则该操作也会增加global_step

    这意味着当您致电global_step时,tf.Variablesess.run(optm))存储的值会增加。

    第2期

    在给定代码的第一行之后,您有一个名为global_step的变量,它是一个tf.Variable对象。重要的是,它没有数值;它只是一个对象的引用,当运行sess.run提供数字值

    为了使图形构建更加方便,Tensorflow允许这样的操作:

    a = tf.constant(1)
    b = a + 2
    

    此时,变量b将是新Tensor对象。我们可以运行sess.run(b)并获取实际值(当然,在初始化之后),但b对象,而不是值。当您运行global_step += 1时,它会创建一个新Tensor对象,当您sess.run时,它会进行一些计算并返回一个数字。

    因此,在global global_step,您仍然可以引用tf.Variable张量,但在第一次循环后,您的global_step将引用: 一个张量,当通过sess.run时,会给你一个在你原来的tf.Variable对象上加1的结果。

    在第二个循环之后,你的global_step指的是一个张量,它会给你在你原来的tf.Variable对象上加1的结果加1的结果。

    当您循环时,您正在添加操作并引用新结果,但从未实际更改为tf.Variable对象存储的值。这就是为什么当你运行sess.run(global_step)时,你会得到你期望的数字,而实际的变量值永远不会改变。