Is TensorFlow's ExponentialMovingAverage unsuitable for batch normalization when using multiple GPUs?

Asked: 2017-06-12 03:29:03

Tags: tensorflow moving-average multi-gpu batch-normalization

Recently, I tried to use multiple GPUs to speed up training, but I ran into a problem with batch normalization. Specifically, when I use ExponentialMovingAverage to maintain the moving averages of the batch mean and variance, the test accuracy is very poor.

I have tried several ways (functions) of implementing batch normalization, shown below. The rest of the code is identical in every experiment; I only swapped in a different batch-normalization function. With 2 GPUs, methods 2-4 work well, but method 1 gives very poor accuracy on the test set. When I switch to a single GPU, all methods work well.

The dataset is CIFAR-10 and the batch size is 128. With 2 GPUs, each GPU processes 64 samples, and the gradients from the GPUs are then averaged, just like the TensorFlow CIFAR10 multi-GPU tutorial.
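The training loop follows the tutorial's tower pattern, roughly like the sketch below (tower_loss and average_gradients stand in for the tutorial's helper functions; other details are omitted):

# Rough sketch of the tutorial-style multi-GPU loop (2 towers, 64 samples each).
tower_grads = []
with tf.variable_scope(tf.get_variable_scope()):
    for i in range(2):
        with tf.device('/gpu:%d' % i):
            with tf.name_scope('tower_%d' % i):
                loss = tower_loss(images_splits[i], labels_splits[i])
                # Share all variables (including the batch-norm ones) across towers.
                tf.get_variable_scope().reuse_variables()
                tower_grads.append(opt.compute_gradients(loss))
# Average the per-variable gradients over the towers, then apply them once.
grads = average_gradients(tower_grads)
train_op = opt.apply_gradients(grads)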

My TensorFlow version is 1.1.0, Python is 2.7, and the OS is Ubuntu 16.04. The CUDA version is 7.5.17.

My question is: why does tf.train.ExponentialMovingAverage (method 1) perform so badly when using multiple GPUs? I am really confused.

Method 1:

def batch_norm_conv(x, n_out=3, phase_train=True, scope='bn_conv'):
    with tf.variable_scope(scope):
        # Learnable scale and shift parameters.
        beta = tf.get_variable('beta_conv', shape=[n_out], initializer=tf.constant_initializer(0.0))
        gamma = tf.get_variable('gamma_conv', shape=[n_out], initializer=tf.constant_initializer(1.0))

        # Batch statistics over the N, H, W axes.
        batch_mean_temp, batch_var_temp = tf.nn.moments(x, [0, 1, 2], name='moments')
        # Non-trainable variables holding the latest batch statistics,
        # so that ema.apply() can track them.
        batch_mean = tf.get_variable('batch_mean', shape=batch_mean_temp.get_shape(), initializer=tf.constant_initializer(0.0), trainable=False)
        batch_var = tf.get_variable('batch_var', shape=batch_var_temp.get_shape(), initializer=tf.constant_initializer(0.0), trainable=False)

        mean_op = tf.assign(batch_mean, batch_mean_temp)
        var_op = tf.assign(batch_var, batch_var_temp)

        ema = tf.train.ExponentialMovingAverage(decay=0.5, zero_debias=False)
        ema_apply_op = ema.apply([batch_mean, batch_var])

        def mean_var_with_update():
            # Make the returned statistics depend on both the assignments
            # and the moving-average update.
            with tf.control_dependencies([mean_op, var_op]):
                with tf.control_dependencies([ema_apply_op]):
                    return tf.identity(batch_mean), tf.identity(batch_var)

        # Batch statistics during training, moving averages at test time.
        mean, var = tf.cond(phase_train,
                            mean_var_with_update,
                            lambda: (ema.average(batch_mean), ema.average(batch_var)))
        normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-5)
    return normed
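Since tf.cond needs a boolean tensor rather than a Python bool, phase_train is passed in as one, e.g. a placeholder; a minimal usage sketch (the names here are illustrative only):

x = tf.placeholder(tf.float32, [None, 32, 32, 3])
phase_train = tf.placeholder(tf.bool, name='phase_train')
out = batch_norm_conv(x, n_out=3, phase_train=phase_train, scope='bn_conv')
# Feed phase_train=True for training steps, False for evaluation.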

Method 2:

def batch_norm_conv(x, n_out=3, phase_train=True, scope='bn_conv'):
    with tf.variable_scope(scope):
        beta = tf.get_variable('beta_conv', shape=[n_out], initializer=tf.constant_initializer(0.0))
        gamma = tf.get_variable('gamma_conv', shape=[n_out], initializer=tf.constant_initializer(1.0))

        batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments')
        mean_average = tf.get_variable('mean_average', shape=batch_mean.get_shape(), initializer=tf.constant_initializer(0.0))
        var_average = tf.get_variable('var_average', shape=batch_var.get_shape(), initializer=tf.constant_initializer(0.0))

        decay = 0.5

        def mean_var_with_update():
            # Manual exponential moving average:
            # new_average = decay * old_average + (1 - decay) * batch_stat
            mean_temp = decay * mean_average + (1 - decay) * batch_mean
            var_temp = decay * var_average + (1 - decay) * batch_var
            mean_op = tf.assign(mean_average, mean_temp)
            var_op = tf.assign(var_average, var_temp)
            with tf.control_dependencies([mean_op, var_op]):
                return tf.identity(batch_mean), tf.identity(batch_var)

        mean, var = tf.cond(phase_train, mean_var_with_update, lambda: (mean_average, var_average))
        normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-3)
    return normed

Method 3:

from tensorflow.python.training import moving_averages

def batch_norm_conv(x, n_out=3, phase_train=True, scope='bn_conv'):
    with tf.variable_scope(scope):
        beta = tf.get_variable('beta_conv', shape=[n_out], initializer=tf.constant_initializer(0.0))
        gamma = tf.get_variable('gamma_conv', shape=[n_out], initializer=tf.constant_initializer(1.0))

        batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments')

        moving_mean = tf.get_variable('batch_mean', shape=batch_mean.get_shape(), initializer=tf.constant_initializer(0.0), trainable=False)
        moving_variance = tf.get_variable('batch_var', shape=batch_var.get_shape(), initializer=tf.constant_initializer(0.0), trainable=False)

        # assign_moving_average updates:
        # variable = variable * decay + value * (1 - decay)
        update_moving_mean = moving_averages.assign_moving_average(
            moving_mean, batch_mean, 0.5, zero_debias=False)
        update_moving_variance = moving_averages.assign_moving_average(
            moving_variance, batch_var, 0.5, zero_debias=False)

        def mean_var_with_update():
            with tf.control_dependencies([update_moving_mean, update_moving_variance]):
                return tf.identity(batch_mean), tf.identity(batch_var)

        mean, var = tf.cond(phase_train,
                            mean_var_with_update,
                            lambda: (moving_mean, moving_variance))

        normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-3)
    return normed

Method 4:

def batch_norm_conv(x, phase_train=True, scope='bn_conv'):
    with tf.variable_scope(scope):
        # Let tf.contrib.layers.batch_norm manage the statistics internally.
        normed = tf.contrib.layers.batch_norm(x,
                                              center=True,
                                              scale=True,
                                              is_training=phase_train,
                                              decay=0.5,
                                              trainable=True)
    return normed
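One note on method 4: by default, tf.contrib.layers.batch_norm places its moving-average update ops in the tf.GraphKeys.UPDATE_OPS collection, so those updates have to run together with the training op (alternatively, passing updates_collections=None makes the updates happen in place), roughly:

# Run the collected batch-norm moving-average updates with each train step.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = opt.apply_gradients(grads)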

Results of method 1:

('epoch', 0, 'test accuracy:', 0.12970000132918358)
('epoch', 1, 'test accuracy:', 0.20419999957084656)
('epoch', 2, 'test accuracy:', 0.11649999991059304)
('epoch', 3, 'test accuracy:', 0.12790000066161156)
('epoch', 4, 'test accuracy:', 0.17040000036358832)
('epoch', 5, 'test accuracy:', 0.15139999836683274)
('epoch', 6, 'test accuracy:', 0.13050000220537186)
('epoch', 7, 'test accuracy:', 0.15879999995231628)
('epoch', 8, 'test accuracy:', 0.17370000183582307)
('epoch', 9, 'test accuracy:', 0.17910000011324884)
('epoch', 10, 'test accuracy:', 0.17960000038146973)
('epoch', 11, 'test accuracy:', 0.12400000095367432)
('epoch', 12, 'test accuracy:', 0.13669999763369561)
('epoch', 13, 'test accuracy:', 0.25510000437498093)
('epoch', 14, 'test accuracy:', 0.18769999742507934)
('epoch', 15, 'test accuracy:', 0.16730000004172324)
('epoch', 16, 'test accuracy:', 0.15510000288486481)
('epoch', 17, 'test accuracy:', 0.19639999866485597)
('epoch', 18, 'test accuracy:', 0.24789999574422836)
('epoch', 19, 'test accuracy:', 0.15929999947547913)
('epoch', 20, 'test accuracy:', 0.17439999729394912)

Results of methods 2-4 (they are not much different, so only one is posted):

('epoch', 0, 'test accuracy:', 0.27250000238418581)
('epoch', 1, 'test accuracy:', 0.42709999978542329)
('epoch', 2, 'test accuracy:', 0.50179999470710757)
('epoch', 3, 'test accuracy:', 0.56709998846054077)
('epoch', 4, 'test accuracy:', 0.59760001301765442)
('epoch', 5, 'test accuracy:', 0.66010000705718996)
('epoch', 6, 'test accuracy:', 0.65400000214576726)
('epoch', 7, 'test accuracy:', 0.69880000352859495)
('epoch', 8, 'test accuracy:', 0.69749999642372129)
('epoch', 9, 'test accuracy:', 0.71029999256134035)
('epoch', 10, 'test accuracy:', 0.72619999051094053)
('epoch', 11, 'test accuracy:', 0.72920000553131104)
('epoch', 12, 'test accuracy:', 0.7372000098228455)
('epoch', 13, 'test accuracy:', 0.75380001068115232)
('epoch', 14, 'test accuracy:', 0.74269998073577881)
('epoch', 15, 'test accuracy:', 0.76199999451637268)
('epoch', 16, 'test accuracy:', 0.7636999785900116)
('epoch', 17, 'test accuracy:', 0.76039999723434448)
('epoch', 18, 'test accuracy:', 0.77150000333786006)
('epoch', 19, 'test accuracy:', 0.77920001149177553)
('epoch', 20, 'test accuracy:', 0.79100000858306885)

0 Answers