How to remove the mean in a Keras layer (like BatchNormalization) without considering the variance?

Asked: 2020-10-30 19:43:11

Tags: tensorflow keras normalization centering batch-normalization

I want to do part of what Keras's BatchNormalization layer does: remove the mean and keep a moving average of it. Unfortunately, the BatchNormalization layer in Keras always takes the variance into account as well, and I don't want that.

I was considering using Average and Subtract layers, but at the end of training they don't store anything that could be reused. The idea is that my layer removes and learns the mean, so that when predicting at test time it subtracts a constant value.
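In other words, the desired behavior is roughly the following (a minimal NumPy sketch; the function names and the momentum value are purely illustrative):

import numpy as np

momentum = 0.99  # fraction of the old moving mean to keep (illustrative)
moving_mean = np.zeros(3)  # one entry per feature

def train_step(batch):
    """Training: center with the batch mean and update the moving mean."""
    global moving_mean
    batch_mean = batch.mean(axis=0)
    moving_mean = momentum * moving_mean + (1.0 - momentum) * batch_mean
    return batch - batch_mean

def predict(batch):
    """Test time: subtract the constant stored during training."""
    return batch - moving_mean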

2 Answers:

Answer 0 (score: 0)

I created a Centering layer for this, copied from the BatchNormalization code. It uses momentum to update the current moving mean. It seems to work, and I can save and load models with it.

from tensorflow.keras import backend
from tensorflow.keras import initializers
from tensorflow.keras import layers
from tensorflow import math
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import ops
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables


class Centering(layers.Layer):
    """Layer that centers the data learning a mean."""

    def __init__(self, momentum=0.01, **kwargs):
        """Constructor of LatentProjection."""
        if 'input_shape' not in kwargs and 'input_dim' in kwargs:
            kwargs['input_shape'] = (kwargs.pop('input_dim'),)
        super().__init__(**kwargs)
        self.input_spec = layers.InputSpec(min_ndim=2)
        self.momentum = momentum
        self.moving_mean = None

    def build(self, input_shape):
        """Create internal variables."""
        assert len(input_shape) >= 2
        input_dim = input_shape[-1]
        self.moving_mean = self.add_weight(
            name='moving_mean',
            shape=(input_dim,),
            initializer=initializers.Zeros,
            synchronization=variables.VariableSynchronization.ON_READ,
            trainable=False,
            aggregation=variables.VariableAggregation.MEAN,
            experimental_autocast=False)
        self.input_spec = layers.InputSpec(min_ndim=2, axes={-1: input_dim})
        self.built = True

    def _get_training_value(self, training=None):
        """Copied from normalization.py."""
        if training is None:
            training = backend.learning_phase()
        if isinstance(training, int):
            training = bool(training)
        if not self.trainable:
            # When the layer is not trainable, it overrides the value
            # passed from model.
            training = False
        return training

    def _support_zero_size_input(self):
        """Copied from normalization.py."""
        return distribution_strategy_context.has_strategy() and getattr(
            distribution_strategy_context.get_strategy().extended,
            'experimental_enable_get_next_as_optional', False)

    def _assign_moving_average(self, variable, value, momentum, inputs_size):
        """Copied from normalization.py."""
        with backend.name_scope('AssignMovingAvg') as scope:
            with ops.colocate_with(variable):
                decay = ops.convert_to_tensor_v2(1.0 - momentum, name='decay')
                if decay.dtype != variable.dtype.base_dtype:
                    decay = math_ops.cast(decay, variable.dtype.base_dtype)
                update_delta = (variable -
                                math_ops.cast(value, variable.dtype)) * decay
                if inputs_size is not None:
                    update_delta = array_ops.where(
                        inputs_size > 0, update_delta,
                        backend.zeros_like(update_delta))
                return state_ops.assign_sub(variable, update_delta, name=scope)

    def call(self, inputs, training=None, **kwargs):
        """Called for each mini batch when applied to input layer."""
        training = self._get_training_value(training)
        training_value = tf_utils.constant_value(training)
        # `== False` (rather than `is False`) is intentional: when the
        # training phase is unknown, training_value is None and falls
        # through to the training branch below.
        if training_value == False:
            mean = self.moving_mean
        else:
            mean = math.reduce_mean(inputs, axis=0)
            # Following code copied from normalization.py to update moving mean
            if self._support_zero_size_input():
                # Keras assumes that batch dimension is the first dimension for
                # Batch Normalization.
                input_batch_size = array_ops.shape(inputs)[0]
            else:
                input_batch_size = None

            def mean_update():
                """Perform update of moving mean average using copied code."""
                self._assign_moving_average(
                    self.moving_mean, mean, self.momentum, input_batch_size)
            self.add_update(mean_update)
        # Center inputs
        return inputs - mean

    def get_config(self):
        """Internal config of this layer."""
        config = {
            'momentum': self.momentum,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))
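
For reference, a minimal usage sketch (the model shape, the data, and the file name are made up for illustration):

import numpy as np
import tensorflow as tf

# Data with a known per-feature offset; moving_mean should move toward it.
x = np.random.normal(loc=[1.0, -2.0, 5.0], size=(256, 3)).astype('float32')
y = np.random.normal(size=(256, 1)).astype('float32')

model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    Centering(),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer='adam', loss='mse')
model.fit(x, y, epochs=5, verbose=0)

print(model.layers[0].get_weights())  # the stored moving mean

# Saving works as usual; loading needs the custom class.
model.save('centering_model.h5')
restored = tf.keras.models.load_model(
    'centering_model.h5', custom_objects={'Centering': Centering})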

Answer 1 (score: 0)

You can also disable scaling in BatchNormalization.

gamma is a learned scaling factor (initialized as 1), which can be disabled by passing scale=False to the constructor.
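For example (note that scale=False only removes the learned gamma multiplier; the activations are still divided by the standard deviation):

import tensorflow as tf

# BatchNormalization without the learned scaling factor gamma.
# The beta offset (center=True, the default) is still learned, and
# division by the (moving) standard deviation still takes place.
layer = tf.keras.layers.BatchNormalization(scale=False)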