tf.keras.layers.BatchNormalization与trainable = False似乎不会更新其内部移动平均值和方差

时间:2020-10-05 06:46:07

标签: tensorflow tensorflow2.0 batch-normalization

我正试图找出BatchNormalization层在TensorFlow中的表现如何。我想出了以下代码,据我所知,这应该是一个完全有效的keras模型,但是BatchNormalization的均值和方差似乎没有更新。

来自文档https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization

在BatchNormalization层的情况下,在该层上设置trainable = False意味着该层随后将以推理模式运行(这意味着它将使用移动均值和移动方差对当前批次进行归一化,而不是使用当前批次的均值和方差)。

我希望模型在随后的每个预测调用中返回不同的值。 但是,我看到的是返回10次的完全相同的值。 谁能向我解释为什么BatchNormalization层不更新其内部值?

import tensorflow as tf
import numpy as np

if __name__ == '__main__':

    np.random.seed(1)
    x = np.random.randn(3, 5) * 5 + 0.3

    bn = tf.keras.layers.BatchNormalization(trainable=False, epsilon=1e-9)
    z = input = tf.keras.layers.Input([5])
    z = bn(z)

    model = tf.keras.Model(inputs=input, outputs=z)

    for i in range(10):
        print(x)
        print(model.predict(x))
        print()

我使用 TensorFlow 2.1.0

1 个答案:

答案 0 :(得分:1)

好的,我在假设中发现了错误。我正在训练中在训练期间对移动平均值进行了更新,而不是我想的那样。这是完全合理的,因为在推理过程中更新移动平均值可能会导致生产模型不稳定(例如,一连串的高度病理性输入样本[例如,其生成分布与训练网络时的生成分布完全不同)]可能会使网络产生偏差,并导致有效输入样本的性能下降。

当您对预训练的模型进行微调并希望冻结网络的某些层(即使在训练过程中)时,可训练参数也很有用。因为当您调用model.predict(x)(甚至是model(x)model(x, training=False))时,图层会自动使用移动平均值而不是批次平均值。

下面的代码清楚地展示了这一点

import tensorflow as tf
import numpy as np

if __name__ == '__main__':

    np.random.seed(1)
    x = np.random.randn(10, 5) * 5 + 0.3

    z = input = tf.keras.layers.Input([5])
    z = tf.keras.layers.BatchNormalization(trainable=True, epsilon=1e-9, momentum=0.99)(z)

    model = tf.keras.Model(inputs=input, outputs=z)
    
    # a dummy loss function
    model.compile(loss=lambda x, y: (x - y) ** 2)

    # a dummy fit just to update the batchnorm moving averages
    model.fit(x, x, batch_size=3, epochs=10)
    
    # first predict uses the moving averages from training
    pred = model(x).numpy()
    print(pred.mean(axis=0))
    print(pred.var(axis=0))
    print()
    
    # outputs the same thing as previous predict
    pred = model(x).numpy()
    print(pred.mean(axis=0))
    print(pred.var(axis=0))
    print()
    
    # here calling the model with training=True results in update of moving averages
    # furthermore, it uses the batch mean and variance as in training, 
    # so the result is very different
    pred = model(x, training=True).numpy()
    print(pred.mean(axis=0))
    print(pred.var(axis=0))
    print()
    
    # here we see again that the moving averages are used but they differ slightly after
    # the previous call, as expected
    pred = model(x).numpy()
    print(pred.mean(axis=0))
    print(pred.var(axis=0))
    print()

最后,我发现文档(https://www.tensorflow.org/api_docs/python/tf/keras/layers/BatchNormalization)提到了这一点:

  1. 使用包含批处理规范化的模型执行推理时,通常(尽管并非总是)希望使用累积统计信息而不是小批量统计信息。这可以通过在调用模型时使用training = False或使用model.predict来实现。

希望这会帮助将来有类似误解的人。