张量流中的批量归一化:变量和性能

时间:2019-04-02 12:23:27

标签: python tensorflow batch-normalization

我想在批处理规范化层的变量上添加条件操作。具体来说,先进行浮动训练,然后在微调的辅助训练阶段进行量化。为此,我想对变量(均值和var的刻度,移位和exp移动平均值)添加tf.cond操作。

我用编写的batchnorm层替换了tf.layers.batch_normalization(见下文)。

此函数运行完美(即,我在两个函数中都获得了相同的指标),并且可以在变量中添加任何管道(在batchnorm操作之前)。 问题是性能(运行时)急剧下降(即,用我自己的函数替换layers.batchnorm就是一个x2因子,如下所述)。

def batchnorm(self, x, name, epsilon=0.001, decay=0.99):
    epsilon = tf.to_float(epsilon)
    decay = tf.to_float(decay)
    with tf.variable_scope(name):
        shape = x.get_shape().as_list()
        channels_num = shape[3]
        # scale factor
        gamma = tf.get_variable("gamma", shape=[channels_num], initializer=tf.constant_initializer(1.0), trainable=True)
        # shift value
        beta = tf.get_variable("beta", shape=[channels_num], initializer=tf.constant_initializer(0.0), trainable=True)
        moving_mean = tf.get_variable("moving_mean", channels_num, initializer=tf.constant_initializer(0.0), trainable=False)
        moving_var = tf.get_variable("moving_var", channels_num, initializer=tf.constant_initializer(1.0), trainable=False)
        batch_mean, batch_var = tf.nn.moments(x, axes=[0, 1, 2]) # per channel

        update_mean = moving_mean.assign((decay * moving_mean) + ((1. - decay) * batch_mean))
        update_var = moving_var.assign((decay * moving_var) + ((1. - decay) * batch_var))

        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_mean)
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_var)

        bn_mean = tf.cond(self.is_training, lambda: tf.identity(batch_mean), lambda: tf.identity(moving_mean))
        bn_var = tf.cond(self.is_training, lambda: tf.identity(batch_var), lambda: tf.identity(moving_var))

        with tf.variable_scope(name + "_batchnorm_op"):
            inv = tf.math.rsqrt(bn_var + epsilon)
            inv *= gamma
            output = ((x*inv) - (bn_mean*inv)) + beta

    return output

对于以下任何问题的帮助,我们将不胜感激:

  • 关于如何提高解决方案性能(减少运行时间)的任何想法吗?
  • 是否可以在batchnorm操作之前将自己的运算符添加到layer.batchnorm的变量管道中?
  • 对同一问题有其他解决方案吗?

谢谢!

1 个答案:

答案 0 :(得分:1)

tf.nn.fused_batch_norm已经过优化并达到了目的。

我必须创建两个子图,每个模式一个,因为fused_batch_norm的界面不采用条件训练/测试模式(is_training是布尔型而不是张量,因此它的图形不是条件性的)。我在之后添加了条件(见下文)。但是,即使有两个子图,它也具有相同的tf.layers.batch_normalization运行时间。

这是最终的解决方案(我仍然感谢您提出任何改进意见或建议):

def batchnorm(self, x, name, epsilon=0.001, decay=0.99):
    with tf.variable_scope(name):
        shape = x.get_shape().as_list()
        channels_num = shape[3]
        # scale factor
        gamma = tf.get_variable("gamma", shape=[channels_num], initializer=tf.constant_initializer(1.0), trainable=True)
        # shift value
        beta = tf.get_variable("beta", shape=[channels_num], initializer=tf.constant_initializer(0.0), trainable=True)
        moving_mean = tf.get_variable("moving_mean", channels_num, initializer=tf.constant_initializer(0.0), trainable=False)
        moving_var = tf.get_variable("moving_var", channels_num, initializer=tf.constant_initializer(1.0), trainable=False)

        (output_train, batch_mean, batch_var) = tf.nn.fused_batch_norm(x,
                                                                 gamma,
                                                                 beta,  # pylint: disable=invalid-name
                                                                 mean=None,
                                                                 variance=None,
                                                                 epsilon=epsilon,
                                                                 data_format="NHWC",
                                                                 is_training=True,
                                                                 name="_batchnorm_op")
        (output_test, _, _) = tf.nn.fused_batch_norm(x,
                                                     gamma,
                                                     beta,  # pylint: disable=invalid-name
                                                     mean=moving_mean,
                                                     variance=moving_var,
                                                     epsilon=epsilon,
                                                     data_format="NHWC",
                                                     is_training=False,
                                                     name="_batchnorm_op")

        output = tf.cond(self.is_training, lambda: tf.identity(output_train), lambda: tf.identity(output_test))

        update_mean = moving_mean.assign((decay * moving_mean) + ((1. - decay) * batch_mean))
        update_var = moving_var.assign((decay * moving_var) + ((1. - decay) * batch_var))
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_mean)
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_var)

    return output