Pytorch Batchnorm层与Keras Batchnorm不同

时间:2019-02-12 12:54:43

标签: python keras deep-learning pytorch batch-normalization

我试图将经过训练的BN权重从pytorch模型复制到等效的Keras模型,但是我一直得到不同的输出。

我阅读了Keras和Pytorch BN文档,我认为区别在于它们计算“均值”和“ var”的方式。

火炬:

  

均值和标准差是在   迷你批次

来源:Pytorch BatchNorm

因此,它们是样本的平均值。

凯拉斯:

  

axis:整数,应归一化的轴(通常为   特征轴)。例如,在Conv2D图层之后   data_format =“ channels_first”,在BatchNormalization中设置axis = 1。

来源:Keras BatchNorm

此处它们是功能(渠道)的平均值

正确的方法是什么?如何在模型之间传递BN权重?

1 个答案:

答案 0 :(得分:0)

您可以从pytorch模块的moving_meanmoving_variance属性中检索running_meanrunning_var

# torch weights, bias, running_mean, running_var corresponds to keras gamma, beta, moving mean, moving average

weights = torch_module.weight.numpy()  
bias = torch_module.bias.numpy()  
running_mean =  torch_module.running_mean.numpy()
running_var =  torch_module.running_var.numpy()

keras_module.set_weights([weights, bias, running_mean, running_var])