Understanding tf.contrib.quantize

Date: 2018-02-21 16:54:32

Tags: python tensorflow batch-normalization

In the TensorFlow package tf.contrib.quantize there is a module that folds batch norm layers into the preceding convolution. It takes a parameter named freeze_batch_norm_delay, which freezes the moving mean and variance of the folded batch norm layers after the given number of steps.

I am training a network (MobileNet + SSD) with tf.contrib.quantize support inserted. After 30k steps the batch norms are frozen (freeze_bn_delay = 30000). This is what happens to the loss:

[Plot of the loss function, with the batch norm freeze occurring at 30k steps]

When the batch norm layers are frozen, the loss jumps abruptly. I expected the loss to be the same before and after freezing, except that the mean and variance would no longer be updated ("frozen").
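As far as I understand it, the jump comes from the fact that at the freeze step the network stops normalizing with the current batch statistics and starts normalizing with the moving averages, and the two generally disagree. A minimal NumPy sketch (with made-up EMA values, not taken from any real training run) of how large that discontinuity can be:

```python
import numpy as np

rng = np.random.default_rng(1)
z = rng.normal(loc=1.0, scale=3.0, size=1000)  # pre-batch-norm activations of one batch
gamma, beta, eps = 1.0, 0.0, 1e-3

# Batch statistics vs. (lagging) moving averages.
mu_b, sigma_b = z.mean(), np.sqrt(z.var() + eps)
mu_mv, sigma_mv = 0.5, np.sqrt(5.0 + eps)      # hypothetical EMA values that lag the batch

y_before = gamma * (z - mu_b) / sigma_b + beta    # normalization before the freeze
y_after = gamma * (z - mu_mv) / sigma_mv + beta   # normalization after the freeze

# The normalized activations change discontinuously at the freeze step,
# which is what shows up as a jump in the loss curve.
print(np.max(np.abs(y_before - y_after)))
```

The closer the moving averages have converged to the true batch statistics, the smaller this jump should be.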

Can someone explain to me what these corrections are?

The source code documents this, but it doesn't help me:

Computes batch norm correction params.
Before batch normalization is frozen:
We use batch statistics for batch norm.
  correction_scale = sigma_b/sigma_mv
  correction_recip = 1/correction_scale
  correction_offset = 0
After batch normalization is frozen:
  correction_scale = sigma_b/sigma_mv
  correction_recip = 1
  correction_offset =  gamma*(mu_b/sigma_b-mu_mv/sigma_mv).
Batch norm is frozen if global_step > bn_freeze_delay.

The corrections ensure that:
 a) The weights are quantized after scaling by gamma/sigma_mv. This enables
    smoother training as the scaling on the weights changes slowly, rather than
    jump across mini-batches
 b) Changing the values of the corrections allows for one to switch between
    using batch statistics to using moving mean and average, without requiring
    changes to batch_norm
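Working through the quoted docstring, the corrections can be checked numerically. The sketch below (a toy single-channel linear layer standing in for the conv, with made-up moving-average values) shows that with the stated correction_scale, correction_recip, and correction_offset, the folded path reproduces batch-statistics batch norm before the freeze and moving-statistics batch norm after it, while the quantized weights are always scaled by the slowly-changing gamma/sigma_mv:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy "conv": a linear layer z = x @ w with one output channel.
x = rng.normal(size=(64, 8))
w = rng.normal(size=(8, 1))
gamma, beta, eps = 1.5, 0.2, 1e-3

z = x @ w                              # pre-batch-norm activations
mu_b = z.mean()                        # batch statistics
sigma_b = np.sqrt(z.var() + eps)
mu_mv, var_mv = 0.3, 2.0               # hypothetical moving averages
sigma_mv = np.sqrt(var_mv + eps)

# Weights are always folded with the *moving* std, so the weights being
# quantized change slowly across mini-batches (point (a) in the docstring).
w_fold = w * gamma / sigma_mv
bias_fold = beta - gamma * mu_b / sigma_b

# --- before freezing: corrections recover batch-statistics batch norm ---
correction_scale = sigma_b / sigma_mv
correction_recip = 1.0 / correction_scale   # multiplies the conv output
correction_offset = 0.0
y_folded = (x @ w_fold) * correction_recip + bias_fold + correction_offset
y_bn_batch = gamma * (z - mu_b) / sigma_b + beta
assert np.allclose(y_folded, y_bn_batch)

# --- after freezing: corrections switch to moving statistics ---
correction_recip = 1.0
correction_offset = gamma * (mu_b / sigma_b - mu_mv / sigma_mv)
y_frozen = (x @ w_fold) * correction_recip + bias_fold + correction_offset
y_bn_moving = gamma * (z - mu_mv) / sigma_mv + beta
assert np.allclose(y_frozen, y_bn_moving)
```

So the corrections only change which statistics the *output* of the folded layer is normalized with; the folded weights themselves keep using sigma_mv throughout.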

Here is the function definition: _ComputeBatchNormCorrections

I don't understand why the text above says that we switch from using batch statistics to using the moving mean and variance at the moment freezing is supposed to happen. Isn't that switch exactly what "freezing" is?

I have searched the web for an answer, but since this package is apparently under active development, I haven't found any explanation.

0 answers:

No answers yet