当我们使用ExponentialMovingAverage训练神经网络时,模型是否默认使用衰减变量?

时间:2018-04-10 01:27:44

标签: tensorflow

当我们使用tf.train.ExponentialMovingAverage.apply(var)维护变量的移动平均值时,如果我们更新变量(如使用tf.assign)来获取衰减变量,我们将使用tf.train.ExponentialMovingAverage。 average(var),但是如果我们直接通过tf.Session.run(var)得到变量,我们将得到没有衰减的变量。

例如:

import tensorflow as tf;  

v1 = tf.Variable(0, dtype=tf.float32)  
ema = tf.train.ExponentialMovingAverage(0.99)  
maintain_average = ema.apply([v1])  

with tf.Session() as sess:  
    init = tf.initialize_all_variables()  
    sess.run(init)  

    print(sess.run([v1, ema.average(v1)])) 
    # Out:[0.0, 0.0]

    sess.run(tf.assign(v1, 5))
    sess.run(maintain_average)  
    print(sess.run([v1, ema.average(v1)]))
    # Out: [10.0, 0.14949986]

因此,当我们使用ExponentialMovingAverage训练神经网络时,模型是否默认使用tf.train.ExponentialMovingAverage.average()的衰减变量?

更具体的例子:

image_tensor = tf.placeholder(tf.float32,
                                  [BATCH_SIZE,IMAGE_SIZE,IMAGE_SIZE,IMAGE_CHANNELS],
                                  'image-tensor')
    label_tensor = tf.placeholder(tf.int32,
                                  [None,10],
                                  'label-tensor')
    net_output = creat_net(image_tensor)
    #suppose creat_net() have build a neural network
    global_step = tf.Variable(0, trainable=False)
    cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=net_output, labels=label_tensor))
    loss = cross_entropy
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
    ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    with tf.control_dependencies([train_step]):
        training_op = ema.apply(tf.trainable_variables())

因此,当我运行training_op来训练网络时,网络将默认使用平均值,或者我需要额外的代码来使用衰减变量?换句话说,GradientDescentOptimizer将使用真值或衰减值来计算下一步的损失?

1 个答案:

答案 0 :(得分:0)

v1是一个具有自己值的变量(在您的情况下为10.0)。

tf.train.ExponentialMovingAverage在内部维护一个变量,每次调用average时都会更新。

每当您使用新输入调用average时,您只需计算指数移动平均线的下一个时间步长(因此只需更改tf.train.ExponentialMovingAverage op的私有变量)而无需更改输入变量。