如何使用tf.nn.batch_normalization处理移动平均值和移动方差?

时间:2018-09-14 14:20:14

标签: python tensorflow batch-normalization

对于我的实现,我必须先定义权重,并且不能在tensorflow中使用高级函数,例如tf.layers.batch_normalization或tf.layers.dense。因此,要进行批处理规范化,我需要使用tf.nn.batch_normalization。我知道,可以使用tf.nn.moments来计算每个小批量的均值和方差,但是移动均值和方差又如何呢?是否有人有这样做的经验或知道实施示例?我看到人们谈论使用tf.nn.batch_normalization可能很棘手,所以我想知道这样做的复杂性。换句话说,是什么使它变得棘手?在实施过程中应注意哪些方面?除了移动平均线和方差之外,我还有其他需要注意的地方吗?

1 个答案:

答案 0 :(得分:2)

您必须对running_meanrunning_variance这两个术语保持警惕。在数学和传统计算机科学中,它们被称为在没有看到完整数据的情况下计算这些值的方法。它们也称为onlinemean的{​​{1}}版本。并不是说他们能够事先准确确定variancemean。随着输入更多数据,它们只是继续更新某些变量variancemean的值。如果您的数据量有限,那么一旦看到完整的数据,它们的值将与值1匹配。如果可以获取完整的数据,则可以进行计算。

批量归一化的情况不同。您不应以与上段相同的方式来思考variancerunning mean

培训时间

在训练期间,为running variance计算meanvariance。它们不是batchrunning mean。因此,您可以安全地使用running variance来做到这一点。

测试时间

在测试期间,您将使用称为tf.nn.momentspopulation_estimated_mean的名称。这些数量是在训练期间计算的,但不能直接使用。计算它们以供以后在测试期间使用。

现在有一个陷阱,就是有些人可能想使用population_estimated_variance来计算这些数量。不建议这样做。 为什么? :因为,培训是在多个Knuth Formula上完成的。因此,同一数据集被看到的次数与epochs的数量一样多。由于数据扩充通常也是随机的,因此计算标准epochsrunning mean可能很危险。相反,通常使用的是running variance

您可以通过在exponentially decaying estimatetf.train.ExponentialMovingAverage上使用batch_mean来实现此目的。在这里,您可以指定与过去的样本相对于当前的样本有多少相关性。通过设置batch_variance,确保用于计算此变量的变量为non-trainable

在测试期间,您将这些变量用作trainable=Falsemean

有关实施的更多详细信息,请查看this link