您可以在带有BatchNormalization的神经网络中解释Keras get_weights()函数吗?

时间:2019-07-18 04:45:46

标签: python keras

当我在Keras中运行神经网络(不使用BatchNormalization)时,我了解了get_weights()函数如何提供神经网络的权重和偏差。但是,使用BatchNorm会产生4个额外的参数,我假设是Gamma,Beta,Mean和Std。

当我保存这些值时,我试图手动复制一个简单的NN,但无法使它们产生正确的输出。有人知道这些值是如何工作的吗?

No Batch Norm

With Batch Norm

1 个答案:

答案 0 :(得分:1)

在简单的多层感知器(MLP)和具有批处理归一化(BN)的MLP的情况下,我将以一个示例来解释get_weights()。

示例:假设我们正在处理MNIST数据集,并使用2层MLP体系结构(即2个隐藏层)。隐藏层1中的神经元数量为392,隐藏层2中的神经元数量为196。因此,我们的MLP的最终架构将为784 x 512 x 196 x 10

这里784是输入图像尺寸,10是输出层尺寸

案例1:不具有批标准化的MLP =>让我的模型名称是使用ReLU激活功能的 model_relu 。现在,在训练 model_relu 之后,我正在使用get_weights(),这将返回大小6的列表,如下面的屏幕截图所示。

get_weights() with simple MLP and without Batch Norm并且列表值如下:

  • (784,392):隐藏的layer1的权重
  • (392,):与隐藏layer1的权重相关的偏差

  • (392,196):隐藏的layer2的权重

  • (196,):与隐藏层2的权重相关的偏差

  • (196,10):输出层的权重

  • (10,):与输出层权重相关的偏差

案例2:具有批处理规范化的MLP =>让我的模型名称为 model_batch ,该模型还使用ReLU激活功能以及批处理规范化。现在,在训练 model_batch 之后,我正在使用get_weights(),这将返回大小14的列表,如下面的屏幕截图所示。

get_weights() with Batch Norm 列表值如下:

  • (784,392):隐藏的layer1的权重
  • (392,):与隐藏layer1的权重相关的偏差
  • (392,)(392,)(392,)(392,):这四个参数是gamma,beta,mean和std。大小为392的dev值每个都与隐藏layer1的批次规范化相关。

  • (392,196):隐藏图层2的权重

  • (196,):与隐藏层2的权重相关的偏差
  • (196,)(196,)(196,)(196,):这四个参数是gamma,β,运行平均值和std。大小为196的dev与隐藏层2的批次标准化相关。

  • (196,10):输出层的权重

  • (10,):与输出层权重相关的偏差

因此,在case2中,如果要获取隐藏层1,隐藏层2和输出层的权重,则python代码可以是这样的:

wrights = model_batch.get_weights()      
hidden_layer1_wt = wrights[0].flatten().reshape(-1,1)     
hidden_layer2_wt = wrights[6].flatten().reshape(-1,1)     
output_layer_wt = wrights[12].flatten().reshape(-1,1)

希望这会有所帮助!

Ref: keras-BatchNormalization