如何在TensorFlow矩中设置轴参数以进行批量归一化?

时间:2019-11-07 17:11:35

标签: python tensorflow tensorflow-datasets batch-normalization

我计划使用this blog来实现类似于tf.nn.batch_normalization的批处理归一化功能(或仅使用tf.nn.moments),以计算均值和方差,但我希望对时态数据执行此操作,矢量和图片类型。我通常在理解如何在axes中正确设置tf.nn.moments参数时遇到一些麻烦。

我的矢量序列输入数据的形状为(batch, timesteps, channels),而我的图像序列输入数据的形状为(batch, timesteps, height, width, 3)(请注意,它们是RGB图像)。在这两种情况下,我都希望在整个批次中跨时间步进行归一化,这意味着我试图针对不同的时间步保持独立的均值/方差。

如何为不同的数据类型(例如图像,矢量)和时态/非时态正确设置axes

1 个答案:

答案 0 :(得分:1)

最简单的想法是-传递到axes的轴将被折叠 ,并且统计信息将通过切片 {{1 }}。示例:

axes
import tensorflow as tf

x = tf.random.uniform((8, 10, 4))

print(x, '\n')
print(tf.nn.moments(x, axes=[0]), '\n')
print(tf.nn.moments(x, axes=[0, 1]))

从源头上,math_ops.reduce_mean用于计算Tensor("random_uniform:0", shape=(8, 10, 4), dtype=float32) (<tf.Tensor 'moments/Squeeze:0' shape=(10, 4) dtype=float32>, <tf.Tensor 'moments/Squeeze_1:0' shape=(10, 4) dtype=float32>) (<tf.Tensor 'moments_1/Squeeze:0' shape=(4,) dtype=float32>, <tf.Tensor 'moments_1/Squeeze_1:0' shape=(4,) dtype=float32>) mean,它们的工作方式如下:

variance

换句话说,# axes = [0] mean = (x[0, :, :] + x[1, :, :] + ... + x[7, :, :]) / 8 mean.shape == (10, 4) # each slice's shape is (10, 4), so sum's shape is also (10, 4) # axes = [0, 1] mean = (x[0, 0, :] + x[1, 0, :] + ... + x[7, 0, :] + x[0, 1, :] + x[1, 1, :] + ... + x[7, 1, :] + ... + x[0, 10, :] + x[1, 10, :] + ... + x[7, 10, :]) / (8 * 10) mean.shape == (4, ) # each slice's shape is (4, ), so sum's shape is also (4, ) 将针对axes=[0]计算(timesteps, channels)统计信息-即,对samples进行迭代,计算samples切片的均值和方差。因此,对于

  

标准化将在整个批次和整个时间步中进行,这意味着我尝试为不同的时间步保持单独的均值/方差

您只需要折叠(timesteps, channels)维度(沿着timesteps),并通过迭代samplessamples来计算统计信息:

timesteps

与图像的故事相同,除了具有两个非通道/样本尺寸外,您将执行axes = [0, 1] (折叠axes = [0, 1, 2])。


伪代码演示:查看实际计算

samples, height, width
import tensorflow as tf
import tensorflow.keras.backend as K
import numpy as np

x = tf.constant(np.random.randn(8, 10, 4))
result1 = tf.add(x[0], tf.add(x[1], tf.add(x[2], tf.add(x[3], tf.add(x[4], 
                       tf.add(x[5], tf.add(x[6], x[7]))))))) / 8
result2 = tf.reduce_mean(x, axis=0)
print(K.eval(result1 - result2))