带有符号API的Mxnet:批处理规范化更新

时间:2019-02-28 22:21:19

标签: c++ mxnet

我目前正在使用Mxnet和C ++ Symbol API来训练卷积神经网络。该网络包含一些Batchnormalization层,其中包含四个参数NDArray。在训练过程中,每批都应更新其中两个参数Moving_mean和moving_variance。

我猜想,由于执行器正向传递的布尔值设置为true,它将自动更新新参数。但是,由于某些原因,这两个NDArray保持不变,而没有任何参数更新。为何如此?此外,由于没有为这两个NDArray计算梯度,因为它不是“可学习的”参数,所以我无法通过常规的优化器更新功能来更新值。如何使用符号API告诉Mxnet更新moving_mean和Moving_variance NDArrays?

1 个答案:

答案 0 :(得分:1)

moving_meanmoving_variance在训练的后退过程中更新,而不是像其他参数一样在优化步骤中更新。如果您在BatchNorm层上设置了use_global_stats=True,那么在训练过程中这些参数仍可以保持不变的另一个原因。