Tensorflow包含函数tf.train.ExponentialMovingAverage,它允许我们对参数应用移动平均值,我发现这些参数对于稳定模型的测试非常有用。
话虽如此,我发现将它应用于一般模型有点令人烦恼。到目前为止,我最成功的方法(如下所示)是编写一个函数装饰器,然后将我的整个NN放在一个函数中。
然而,这有几个缺点。首先,它复制整个图形,其次,我需要在函数内定义我的NN。
有更好的方法吗?
def ema_wrapper(is_training, decay=0.99):
    """Use Exponential Moving Average of parameters during testing.

    Returns a decorator: the wrapped model-building function is called
    twice on the same graph (once normally, once under a ``custom_getter``
    that substitutes EMA shadow variables), and ``tf.cond`` selects the
    appropriate result at run time.

    Parameters
    ----------
    is_training : bool or `tf.Tensor` of type bool
        EMA is applied if ``is_training`` is False.
    decay:
        Decay rate for `tf.train.ExponentialMovingAverage`
    """
    def function(fun):
        @functools.wraps(fun)
        def fun_wrapper(*args, **kwargs):
            # Regular call: builds the variables under a known scope so we
            # can collect exactly the trainables this model created.
            with tf.variable_scope('ema_wrapper', reuse=False) as scope:
                result_train = fun(*args, **kwargs)

            # Set up exponential moving average over those variables.
            ema = tf.train.ExponentialMovingAverage(decay=decay)
            var_class = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          scope.name)
            ema_op = ema.apply(var_class)
            # Add to UPDATE_OPS so the shadow variables are refreshed as
            # part of the usual training-step dependencies.
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, ema_op)

            # Getter that returns the EMA shadow of each variable.
            def ema_getter(getter, name, *args, **kwargs):
                var = getter(name, *args, **kwargs)
                ema_var = ema.average(var)
                # BUG FIX: the original `ema_var if ema_var else var`
                # truth-tests a tf.Variable (disallowed in graph mode) and
                # ema.average() legitimately returns None for variables it
                # does not track — compare against None explicitly.
                return ema_var if ema_var is not None else var

            # Second call, reusing the same variables but routed through
            # the EMA getter.
            with tf.variable_scope('ema_wrapper', reuse=True,
                                   custom_getter=ema_getter):
                result_test = fun(*args, **kwargs)

            # Return the correct version depending on if we're training or not
            return tf.cond(is_training,
                           lambda: result_train, lambda: result_test)
        return fun_wrapper
    return function
使用示例:
@ema_wrapper(is_training)
def neural_network(x):
    """Toy example model: multiply the input by a single learned scalar."""
    # If is_training is False, we will use an EMA of a instead
    a = tf.get_variable('a', [], tf.float32)
    return a * x
答案 0(得分:10):
您可以使用将EMA变量中的值传输到原始变量的操作:
import tensorflow as tf

# Make model...
minimize_op = ...
model_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

# Make EMA object and update internal variables after optimization step.
# Chaining ema.apply() behind minimize_op means one sess.run(train_op)
# both applies gradients and refreshes the EMA shadow variables.
ema = tf.train.ExponentialMovingAverage(decay=decay)
with tf.control_dependencies([minimize_op]):
    train_op = ema.apply(model_vars)

# Transfer EMA values to original variables.
# NOTE(review): this overwrites the raw weights in place with their EMA
# shadows — there is no restore step in this short version, so training
# cannot simply resume afterwards.
retrieve_ema_weights_op = tf.group(
    [tf.assign(var, ema.average(var)) for var in model_vars])

with tf.Session() as sess:
    # Do training
    while ...:
        sess.run(train_op, ...)
    # Copy EMA values to weights
    sess.run(retrieve_ema_weights_op)
    # Test model with EMA weights
    # ...
编辑:
我写了一个更长的版本,借助备份变量可以在训练模式和测试模式之间来回切换:
import tensorflow as tf

# Make model...
minimize_op = ...
model_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

# Graph-level flag recording the current mode (True = training), so mode
# switches are idempotent: switching twice does not clobber the backups.
is_training = tf.get_variable('is_training', shape=(), dtype=tf.bool,
                              initializer=tf.constant_initializer(True, dtype=tf.bool))

# Make EMA object and update internal variables after optimization step
ema = tf.train.ExponentialMovingAverage(decay=decay)
with tf.control_dependencies([minimize_op]):
    train_op = ema.apply(model_vars)

# Make backup variables: non-trainable copies used to stash the raw
# weights while the EMA values temporarily occupy the model variables.
with tf.variable_scope('BackupVariables'):
    backup_vars = [tf.get_variable(var.op.name, dtype=var.value().dtype, trainable=False,
                                   initializer=var.initialized_value())
                   for var in model_vars]

# Copy each EMA shadow value into its corresponding model variable.
def ema_to_weights():
    return tf.group(*(tf.assign(var, ema.average(var).read_value())
                      for var in model_vars))

# Stash the current (raw) weights into the backup variables.
def save_weight_backups():
    return tf.group(*(tf.assign(bck, var.read_value())
                      for var, bck in zip(model_vars, backup_vars)))

# Restore the raw weights from the backups (undo ema_to_weights).
def restore_weight_backups():
    return tf.group(*(tf.assign(var, bck.read_value())
                      for var, bck in zip(model_vars, backup_vars)))

# Enter training mode: flip the flag, then restore the raw weights.
def to_training():
    with tf.control_dependencies([tf.assign(is_training, True)]):
        return restore_weight_backups()

# Enter test mode: flip the flag, back up the raw weights, THEN overwrite
# them with the EMA values — the nested control_dependencies enforce that
# the backup happens before the overwrite.
def to_testing():
    with tf.control_dependencies([tf.assign(is_training, False)]):
        with tf.control_dependencies([save_weight_backups()]):
            return ema_to_weights()

# tf.cond makes each switch a no-op when already in the requested mode.
switch_to_train_mode_op = tf.cond(is_training, lambda: tf.group(), to_training)
switch_to_test_mode_op = tf.cond(is_training, to_testing, lambda: tf.group())

init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)
    # Unnecessary, since it begins in training mode, but unharmful
    sess.run(switch_to_train_mode_op)
    # Do training
    while ...:
        sess.run(train_op, ...)
    # To test mode
    sess.run(switch_to_test_mode_op)
    # Switching multiple times should not overwrite backups
    sess.run(switch_to_test_mode_op)
    # Test model with EMA weights
    # ...
    # Back to training mode
    sess.run(switch_to_train_mode_op)
    # Keep training...