如果在测试时在训练模式中使用批量标准化怎么办?

时间:2017-09-19 02:34:50

标签: tensorflow normalization

批量标准化在训练阶段和测试阶段具有不同的行为。

例如,在tensorflow中使用tf.contrib.layers.batch_norm时,我们应该在不同阶段为is_training设置不同的值。

我的 qusetion 是:如果我在测试时仍设置is_training=True该怎么办?那就是说如果我仍然在测试阶段使用训练模式怎么办?

我提出这个问题的原因是,Pix2PixDualGAN的已发布代码在测试时不会设置is_training=False。并且似乎如果在测试时设置is_training=False,则生成的图像的质量可能非常糟糕。

有人可以解释一下吗?感谢。

1 个答案:

答案 0 :(得分:4)

在训练期间,BatchNorm层尝试做两件事:

  • 估算整个训练集(人口统计)的均值和方差
  • 对输入均值和方差进行归一化,使得它们的行为类似于高斯

在理想情况下,可以在第二点使用整个数据集的总体统计量。然而,这些是未知的并且在训练期间发生变化还有其他一些问题。

解决方法正在通过

对输入进行规范化
gamma * (x - mean) / sigma + b

基于小批量统计信息meansigma

在培训期间,小批量统计信息的运行平均值用于估算人口统计信息。

现在,原始BatchNorm公式使用整个数据集的近似均值和方差在推理期间进行标准化。由于网络是固定的,meanvariance的近似应该非常好。虽然现在使用人口统计数据似乎是有意义的,但这是一个关键的变化:从小批量统计数据到整个训练数据的统计数据。

批次在培训期间不是iid或批量非常小是至关重要的。 (但我也观察到它的批量为32)。

建议的BatchNorm隐含地假设两个统计信息非常相似。特别是,像pix2pix或dualgan一样对1号小批量的小批量训练给出了关于人口统计数据的非常糟糕的信息。在这种情况下,它们可能包含完全不同的值。

现在有一个深层网络,后期层希望输入是规范化批次(在小批量统计意义上)。请注意,他们接受过这种特殊数据的培训。但是使用整个数据集统计数据违反了推理过程中的假设。

如何解决这个问题?在您推荐的实现中,也可以在推理期间使用小批量统计信息。或者使用BatchReNormalization引入2个附加术语来消除小批量和人口统计之间的差异 或者只是使用InstanceNormalization(用于回归任务),实际上它与BatchNorm相同,但是单独处理批处理中的每个示例,也不使用人口统计。

我在研究期间也遇到过这个问题,现在用于回归任务InstanceNorm层。