How to use an Exponential Moving Average in TensorFlow

Date: 2018-03-07 09:09:44

Tags: python tensorflow moving-average

Question

TensorFlow includes the function tf.train.ExponentialMovingAverage, which lets us apply a moving average to parameters; I have found this very useful for stabilizing the model at test time.
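
For reference, a minimal sketch of that API in TF 1.x (the variable `w` and the fake training step are purely illustrative, not part of my model): ema.apply(...) creates shadow copies of the given variables and returns the op that updates them, and ema.average(var) returns the shadow variable holding the moving average.

import tensorflow as tf

# Illustrative variable and a stand-in for an optimizer step
w = tf.get_variable('w', shape=[], initializer=tf.zeros_initializer())
train_step = tf.assign_add(w, 1.0)

ema = tf.train.ExponentialMovingAverage(decay=0.99)
with tf.control_dependencies([train_step]):
    ema_update = ema.apply([w])  # creates the shadow variable, returns the update op

w_avg = ema.average(w)           # reads the moving average of `w`

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(5):
        sess.run(ema_update)     # runs the training step, then the EMA update
    print(sess.run([w, w_avg]))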

That said, I find it somewhat cumbersome to apply to a general model. My most successful approach so far (shown below) is to write a function decorator and then put my whole NN inside a function.

However, this has a few drawbacks. First, it duplicates the whole graph, and second, I need to define my NN inside a function.

Is there a better way?

Current implementation

import functools

import tensorflow as tf


def ema_wrapper(is_training, decay=0.99):
    """Use Exponential Moving Average of parameters during testing.

    Parameters
    ----------
    is_training : bool or `tf.Tensor` of type bool
        EMA is applied if ``is_training`` is False.
    decay : float
        Decay rate for `tf.train.ExponentialMovingAverage`
    """
    def function(fun):
        @functools.wraps(fun)
        def fun_wrapper(*args, **kwargs):
            # Regular call
            with tf.variable_scope('ema_wrapper', reuse=False) as scope:
                result_train = fun(*args, **kwargs)

            # Set up exponential moving average
            ema = tf.train.ExponentialMovingAverage(decay=decay)
            var_class = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          scope.name)
            ema_op = ema.apply(var_class)

            # Add to collection so they are updated
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, ema_op)

            # Getter for the variables with EMA applied
            def ema_getter(getter, name, *args, **kwargs):
                var = getter(name, *args, **kwargs)
                ema_var = ema.average(var)
                # Fall back to the original variable if no EMA shadow exists
                return ema_var if ema_var is not None else var

            # Call with EMA applied
            with tf.variable_scope('ema_wrapper', reuse=True,
                                   custom_getter=ema_getter):
                result_test = fun(*args, **kwargs)

            # Return the correct version depending on if we're training or not
            return tf.cond(is_training,
                           lambda: result_train, lambda: result_test)
        return fun_wrapper
    return function

Example usage:

@ema_wrapper(is_training)
def neural_network(x):
    # If is_training is False, we will use an EMA of a instead
    a = tf.get_variable('a', [], tf.float32)
    return a * x
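
For completeness, the ema_op added to tf.GraphKeys.UPDATE_OPS in the wrapper only runs if the training op actually depends on that collection. A minimal sketch of what that might look like (the placeholder, loss, and optimizer here are illustrative, not part of the question):

x = tf.placeholder(tf.float32, shape=(), name='x')        # illustrative input
y = neural_network(x)                                      # decorated function from above
loss = tf.square(y - 1.0)                                  # illustrative loss

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)    # contains ema_op
with tf.control_dependencies(update_ops):
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

# sess.run(train_op, feed_dict={x: 1.0, is_training: True}) then updates both
# the weights and their moving averages (assuming is_training is a bool placeholder).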

1 Answer:

Answer (score: 10):

You can have an op that transfers the values from the EMA variables to the original variables:

import tensorflow as tf

# Make model...
minimize_op = ...
model_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
# Make EMA object and update internal variables after optimization step
ema = tf.train.ExponentialMovingAverage(decay=decay)
with tf.control_dependencies([minimize_op]):
    train_op = ema.apply(model_vars)

# Transfer EMA values to original variables
retrieve_ema_weights_op = tf.group(
    [tf.assign(var, ema.average(var)) for var in model_vars])

with tf.Session() as sess:
    # Do training
    while ...:
        sess.run(train_op, ...)
    # Copy EMA values to weights
    sess.run(retrieve_ema_weights_op)
    # Test model with EMA weights
    # ...
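
A side note, not in the original answer: the assign above overwrites the trained weights with the EMA values, which is why the edit below keeps backups. If you checkpoint the model anyway, another common TF 1.x option is to restore the EMA values through a tf.train.Saver built from ema.variables_to_restore(); roughly (ckpt_path is an illustrative name):

saver = tf.train.Saver(ema.variables_to_restore())
with tf.Session() as sess:
    saver.restore(sess, ckpt_path)  # loads the EMA values into the original variables
    # Test model with EMA weights
    # ...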

EDIT:

I made a longer version that handles switching between training and test mode, with variable backups:

import tensorflow as tf

# Make model...
minimize_op = ...
model_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

is_training = tf.get_variable('is_training', shape=(), dtype=tf.bool,
                              initializer=tf.constant_initializer(True, dtype=tf.bool))

# Make EMA object and update internal variables after optimization step
ema = tf.train.ExponentialMovingAverage(decay=decay)
with tf.control_dependencies([minimize_op]):
    train_op = ema.apply(model_vars)
# Make backup variables
with tf.variable_scope('BackupVariables'):
    backup_vars = [tf.get_variable(var.op.name, dtype=var.value().dtype, trainable=False,
                                   initializer=var.initialized_value())
                   for var in model_vars]

def ema_to_weights():
    return tf.group(*(tf.assign(var, ema.average(var).read_value())
                     for var in model_vars))
def save_weight_backups():
    return tf.group(*(tf.assign(bck, var.read_value())
                     for var, bck in zip(model_vars, backup_vars)))
def restore_weight_backups():
    return tf.group(*(tf.assign(var, bck.read_value())
                     for var, bck in zip(model_vars, backup_vars)))

def to_training():
    with tf.control_dependencies([tf.assign(is_training, True)]):
        return restore_weight_backups()

def to_testing():
    with tf.control_dependencies([tf.assign(is_training, False)]):
        with tf.control_dependencies([save_weight_backups()]):
            return ema_to_weights()

switch_to_train_mode_op = tf.cond(is_training, lambda: tf.group(), to_training)
switch_to_test_mode_op = tf.cond(is_training, to_testing, lambda: tf.group())

init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)
    # Unnecessary, since it begins in training mode, but harmless
    sess.run(switch_to_train_mode_op)
    # Do training
    while ...:
        sess.run(train_op, ...)
    # To test mode
    sess.run(switch_to_test_mode_op)
    # Switching multiple times should not overwrite backups
    sess.run(switch_to_test_mode_op)
    # Test model with EMA weights
    # ...
    # Back to training mode
    sess.run(switch_to_train_mode_op)
    # Keep training...