最近,我尝试使用多个GPU来加速训练,但在批量归一化(batch normalization)方面遇到了一些问题。具体来说,当我使用ExponentialMovingAverage来计算批量均值和方差的滑动平均时,准确率很差。
我已经尝试了几种方法(函数)来实现批量归一化,如下所示。代码的其余部分完全相同,我只是换用了不同的批量归一化函数。使用2个GPU时,方法2-4运行良好,但方法1在测试数据集上的准确率非常差。当我切换到仅使用1个GPU时,所有方法都能很好地工作。
数据集为CIFAR10,批量大小为128。当使用2个GPU时,每个GPU处理64个样本,然后对各个GPU的梯度取平均,做法与TensorFlow的CIFAR10多GPU教程相同。
我的tensorflow版本是1.1.0,python版本是2.7,OS是ubuntu 16.04。 CUDA版本是7.5.17。
我的问题是为什么使用tf.train.ExponentialMovingAverage(方法1)在使用多个GPU时效果如此糟糕?我真的很困惑。
方法1:
def batch_norm_conv(x, n_out=3, phase_train=True, scope='bn_conv'):
    """Batch-normalise ``x``, tracking statistics with ExponentialMovingAverage.

    In training mode the current batch statistics are used and the moving
    averages are refreshed; otherwise the moving averages are used.

    Args:
        x: Tensor to normalise. Moments are reduced over axes [0, 1, 2], so
            this is presumably a 4-D NHWC activation -- TODO confirm.
        n_out: Length of the learned ``beta`` (offset) and ``gamma`` (scale)
            vectors.
        phase_train: Boolean scalar tensor (e.g. a placeholder) or a plain
            Python bool; True selects the training path.
        scope: Name of the variable scope that holds this layer's variables.

    Returns:
        The normalised tensor produced by ``tf.nn.batch_normalization``.
    """
    # tf.cond expects a tensor predicate (TF 1.x raises TypeError for a Python
    # bool), so the default phase_train=True could not be used as-is. Tensors
    # pass through convert_to_tensor unchanged.
    phase_train = tf.convert_to_tensor(phase_train, dtype=tf.bool)

    with tf.variable_scope(scope):
        beta = tf.get_variable(
            'beta_conv', shape=[n_out],
            initializer=tf.constant_initializer(0.0))
        gamma = tf.get_variable(
            'gamma_conv', shape=[n_out],
            initializer=tf.constant_initializer(1.0))

        batch_mean_temp, batch_var_temp = tf.nn.moments(
            x, [0, 1, 2], name='moments')

        # The batch moments are copied into non-trainable variables so that
        # ExponentialMovingAverage can shadow them.
        batch_mean = tf.get_variable(
            'batch_mean', shape=batch_mean_temp.get_shape(),
            initializer=tf.constant_initializer(0.0), trainable=False)
        batch_var = tf.get_variable(
            'batch_var', shape=batch_var_temp.get_shape(),
            initializer=tf.constant_initializer(0.0), trainable=False)
        mean_op = tf.assign(batch_mean, batch_mean_temp)
        var_op = tf.assign(batch_var, batch_var_temp)

        ema = tf.train.ExponentialMovingAverage(decay=0.5, zero_debias=False)
        ema_apply_op = ema.apply([batch_mean, batch_var])

        # NOTE(review): mean_op, var_op and ema_apply_op are created outside
        # the tf.cond branch function below. Per the tf.cond documentation,
        # only ops created inside a branch are conditional, so these updates
        # presumably also run on the inference path -- verify.
        # NOTE(review): the training path reads the batch_mean / batch_var
        # variables back instead of using the moment tensors directly. If this
        # scope is reused by several GPU towers, the variables are presumably
        # shared and one tower may read statistics written by another -- TODO
        # confirm; this may be related to the poor multi-GPU accuracy.
        def mean_var_with_update():
            with tf.control_dependencies([mean_op, var_op]):
                with tf.control_dependencies([ema_apply_op]):
                    return tf.identity(batch_mean), tf.identity(batch_var)

        mean, var = tf.cond(
            phase_train,
            mean_var_with_update,
            lambda: (ema.average(batch_mean), ema.average(batch_var)))
        normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-5)
    return normed
方法2:
def batch_norm_conv(x, n_out=3, phase_train=True, scope='bn_conv'):
    """Batch-normalise ``x`` with a hand-written moving average of the moments.

    In training mode the current batch statistics are used and the running
    averages are updated as ``decay * average + (1 - decay) * batch_value``;
    otherwise the running averages are used.

    Args:
        x: Tensor to normalise. Moments are reduced over axes [0, 1, 2], so
            this is presumably a 4-D NHWC activation -- TODO confirm.
        n_out: Length of the learned ``beta`` (offset) and ``gamma`` (scale)
            vectors.
        phase_train: Boolean scalar tensor (e.g. a placeholder) or a plain
            Python bool; True selects the training path.
        scope: Name of the variable scope that holds this layer's variables.

    Returns:
        The normalised tensor produced by ``tf.nn.batch_normalization``.
    """
    # tf.cond expects a tensor predicate (TF 1.x raises TypeError for a Python
    # bool), so the default phase_train=True could not be used as-is. Tensors
    # pass through convert_to_tensor unchanged.
    phase_train = tf.convert_to_tensor(phase_train, dtype=tf.bool)

    with tf.variable_scope(scope):
        beta = tf.get_variable(
            'beta_conv', shape=[n_out],
            initializer=tf.constant_initializer(0.0))
        gamma = tf.get_variable(
            'gamma_conv', shape=[n_out],
            initializer=tf.constant_initializer(1.0))

        batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments')
        # Debug aid: prints the static shape while the graph is being built.
        print(batch_mean.get_shape())

        # The running statistics are maintained by the explicit assigns below,
        # not by the optimizer, so they are created as non-trainable.
        mean_average = tf.get_variable(
            'mean_average', shape=batch_mean.get_shape(),
            initializer=tf.constant_initializer(0.0), trainable=False)
        var_average = tf.get_variable(
            'var_average', shape=batch_var.get_shape(),
            initializer=tf.constant_initializer(0.0), trainable=False)
        decay = 0.5

        def mean_var_with_update():
            # The update ops are created inside the branch function, so they
            # only run when the training branch is selected.
            mean_temp = decay * mean_average + (1 - decay) * batch_mean
            var_temp = decay * var_average + (1 - decay) * batch_var
            mean_op = tf.assign(mean_average, mean_temp)
            var_op = tf.assign(var_average, var_temp)
            with tf.control_dependencies([mean_op, var_op]):
                return tf.identity(batch_mean), tf.identity(batch_var)

        mean, var = tf.cond(
            phase_train,
            mean_var_with_update,
            lambda: (mean_average, var_average))
        normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-3)
    return normed
方法3:
from tensorflow.python.training import moving_averages
def batch_norm_conv(x, n_out=3, phase_train=True, scope='bn_conv'):
    """Batch-normalise ``x`` using ``moving_averages.assign_moving_average``.

    In training mode the current batch statistics are used and the moving
    mean / variance are updated with decay 0.5; otherwise the stored moving
    statistics are used.

    Args:
        x: Tensor to normalise. Moments are reduced over axes [0, 1, 2], so
            this is presumably a 4-D NHWC activation -- TODO confirm.
        n_out: Length of the learned ``beta`` (offset) and ``gamma`` (scale)
            vectors.
        phase_train: Boolean scalar tensor (e.g. a placeholder) or a plain
            Python bool; True selects the training path.
        scope: Name of the variable scope that holds this layer's variables.

    Returns:
        The normalised tensor produced by ``tf.nn.batch_normalization``.
    """
    # tf.cond expects a tensor predicate (TF 1.x raises TypeError for a Python
    # bool), so the default phase_train=True could not be used as-is. Tensors
    # pass through convert_to_tensor unchanged.
    phase_train = tf.convert_to_tensor(phase_train, dtype=tf.bool)

    with tf.variable_scope(scope):
        beta = tf.get_variable(
            'beta_conv', shape=[n_out],
            initializer=tf.constant_initializer(0.0))
        gamma = tf.get_variable(
            'gamma_conv', shape=[n_out],
            initializer=tf.constant_initializer(1.0))

        batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments')

        moving_mean = tf.get_variable(
            'batch_mean', shape=batch_mean.get_shape(),
            initializer=tf.constant_initializer(0.0), trainable=False)
        moving_variance = tf.get_variable(
            'batch_var', shape=batch_var.get_shape(),
            initializer=tf.constant_initializer(0.0), trainable=False)

        def mean_var_with_update():
            # Build the update ops inside the branch function: only ops
            # created inside a tf.cond branch are conditional. Created
            # outside and merely attached as control dependencies, they would
            # also run when the inference branch is taken, folding evaluation
            # batches into the moving statistics.
            update_moving_mean = moving_averages.assign_moving_average(
                moving_mean, batch_mean, 0.5, zero_debias=False)
            update_moving_variance = moving_averages.assign_moving_average(
                moving_variance, batch_var, 0.5, zero_debias=False)
            with tf.control_dependencies(
                    [update_moving_mean, update_moving_variance]):
                return tf.identity(batch_mean), tf.identity(batch_var)

        mean, var = tf.cond(
            phase_train,
            mean_var_with_update,
            lambda: (moving_mean, moving_variance))
        normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-3)
    return normed
方法4:
def batch_norm_conv(x, phase_train=True, scope='bn_conv'):
    """Apply ``tf.contrib.layers.batch_norm`` to ``x`` inside ``scope``.

    Args:
        x: Tensor to normalise.
        phase_train: Forwarded to the layer as ``is_training``.
        scope: Name of the variable scope the layer is built in.

    Returns:
        The output of ``tf.contrib.layers.batch_norm``.
    """
    # Fixed layer configuration: learned offset and scale, decay of 0.5.
    bn_options = dict(
        center=True,
        scale=True,
        is_training=phase_train,
        decay=0.5,
        trainable=True,
    )
    with tf.variable_scope(scope):
        return tf.contrib.layers.batch_norm(x, **bn_options)
方法1的结果:
('epoch', 0, 'test accuracy:', 0.12970000132918358)
('epoch', 1, 'test accuracy:', 0.20419999957084656)
('epoch', 2, 'test accuracy:', 0.11649999991059304)
('epoch', 3, 'test accuracy:', 0.12790000066161156)
('epoch', 4, 'test accuracy:', 0.17040000036358832)
('epoch', 5, 'test accuracy:', 0.15139999836683274)
('epoch', 6, 'test accuracy:', 0.13050000220537186)
('epoch', 7, 'test accuracy:', 0.15879999995231628)
('epoch', 8, 'test accuracy:', 0.17370000183582307)
('epoch', 9, 'test accuracy:', 0.17910000011324884)
('epoch', 10, 'test accuracy:', 0.17960000038146973)
('epoch', 11, 'test accuracy:', 0.12400000095367432)
('epoch', 12, 'test accuracy:', 0.13669999763369561)
('epoch', 13, 'test accuracy:', 0.25510000437498093)
('epoch', 14, 'test accuracy:', 0.18769999742507934)
('epoch', 15, 'test accuracy:', 0.16730000004172324)
('epoch', 16, 'test accuracy:', 0.15510000288486481)
('epoch', 17, 'test accuracy:', 0.19639999866485597)
('epoch', 18, 'test accuracy:', 0.24789999574422836)
('epoch', 19, 'test accuracy:', 0.15929999947547913)
('epoch', 20, 'test accuracy:', 0.17439999729394912)
方法2 - 4的结果(它们没有太大区别,所以只发布其中一个):
('epoch', 0, 'test accuracy:', 0.27250000238418581)
('epoch', 1, 'test accuracy:', 0.42709999978542329)
('epoch', 2, 'test accuracy:', 0.50179999470710757)
('epoch', 3, 'test accuracy:', 0.56709998846054077)
('epoch', 4, 'test accuracy:', 0.59760001301765442)
('epoch', 5, 'test accuracy:', 0.66010000705718996)
('epoch', 6, 'test accuracy:', 0.65400000214576726)
('epoch', 7, 'test accuracy:', 0.69880000352859495)
('epoch', 8, 'test accuracy:', 0.69749999642372129)
('epoch', 9, 'test accuracy:', 0.71029999256134035)
('epoch', 10, 'test accuracy:', 0.72619999051094053)
('epoch', 11, 'test accuracy:', 0.72920000553131104)
('epoch', 12, 'test accuracy:', 0.7372000098228455)
('epoch', 13, 'test accuracy:', 0.75380001068115232)
('epoch', 14, 'test accuracy:', 0.74269998073577881)
('epoch', 15, 'test accuracy:', 0.76199999451637268)
('epoch', 16, 'test accuracy:', 0.7636999785900116)
('epoch', 17, 'test accuracy:', 0.76039999723434448)
('epoch', 18, 'test accuracy:', 0.77150000333786006)
('epoch', 19, 'test accuracy:', 0.77920001149177553)
('epoch', 20, 'test accuracy:', 0.79100000858306885)