批量标准化有哪些可学习的参数?

时间:2016-02-26 13:29:15

标签: tensorflow deep-learning

有四个变量存在问题:gamma,beta,均值移动平均值,方差移动平均值。

是否需要对移动平均值进行快照,并在测试时加载它们?

一个更好的问题:

对于张量流中的this implementation批量标准化,我是否需要将批处理均值和批处理变量从训练时间转移到测试时间?如果是这样,我怎样才能在tensorflow中实现呢?

1 个答案:

答案 0 :(得分:3)

是 - 对于任何批量规范化的使用,您可以通过基于单个批次的统计数据进行规范化来训练,但是然后使用统计数据的长期平均值来进行推理。

你应该保存你的均值和方差保持变量的副本,并在你进行测试时恢复它。

不应该有任何魔法要求:它们只是use the Saver时保存和恢复的变量。

在您引用的具体实现中,documentation for tf.train.ExponentialMovingAverage有一个具体示例,说明如何分别保存和恢复训练和推理的移动平均值。