TensorFlow >= r1.0:使用 tf.layers.batch_normalization 后测试性能非常差

时间:2017-06-04 02:23:05

标签: tensorflow deep-learning

我尝试使用最新 TensorFlow API 中提供的 tf.layers.batch_normalization 函数,来实现循环批量标准化(Recurrent Batch Normalization)的 LSTM。

实现如下(在 TF 自带 LSTMCell 源代码的基础上修改而来):



class BNLSTMCell(tf.nn.rnn_cell.RNNCell):
  """
  Batch Normalized Long short-term memory unit (LSTM) recurrent network cell.

  Batch normalization (via `tf.layers.batch_normalization`) is applied
  separately to the input-to-hidden product, the hidden-to-hidden product,
  and the cell state before the output activation.

  cf. Recurrent Batch Normalization
    https://arxiv.org/abs/1603.09025

  cf. A Gentle Guide to Using Batch Normalization in TensorFlow
    http://ruishu.io/2016/12/27/batchnorm/

  NOTE(review): the batch-norm layers use fixed names inside this cell's
  variable scope and `reuse=None` (inherit the enclosing scope's reuse flag).
  So every time step presumably shares one set of moving mean/variance,
  whereas the paper keeps separate population statistics per time step.
  This holds whether the RNN driver reuses the scope per unrolled step or
  calls the cell once inside a loop. Combined with a low `momentum`, this is
  a plausible cause of poor inference-mode (forward_only=True) results.
  Verify before relying on it.

  NOTE: in training mode the moving-average updates are only registered in
  `tf.GraphKeys.UPDATE_OPS`. The caller must run them, e.g. as a control
  dependency of the train op.
  """

  def __init__(self, num_units, forward_only, gamma_c=1.0, gamma_h=1.0,
               gamma_x=1.0, beta_c=0.0, beta_h=0.0, beta_x=0.0,
               input_size=None, use_peepholes=False, cell_clip=None,
               initializer=None, num_proj=None,
               num_unit_shards=1, num_proj_shards=1,
               forget_bias=1.0, state_is_tuple=False,
               activation=tf.tanh, momentum=0.5):
    """Initialize the parameters for a batch-normalized LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell
      forward_only: Python bool.
        If False (training):
          1. Normalize layer activations according to mini-batch statistics.
          2. During the training step, update population statistics
             approximation via moving average of mini-batch statistics.
        If True (testing):
          1. Normalize layer activations according to estimated population
             statistics.
          2. No update of population statistics according to mini-batch
             statistics from test data.
      gamma_c: Scale of cell state normalization
      gamma_h: Scale of hidden state normalization
      gamma_x: Scale of input normalization
      beta_c: Offset of cell state normalization
      beta_h: Offset of hidden state normalization
        (set to 0 to avoid redundancy)
      beta_x: Offset of input normalization
        (set to 0 to avoid redundancy)
      input_size: Deprecated and unused.
      use_peepholes: bool, Set True to enable diagonal/peephole connections.
      cell_clip: (optional) A float value, if provided the cell state is clipped
        by this value prior to the cell output activation.
      initializer: (optional) The initializer to use for the weight and
        projection matrices.
      num_proj: (optional) int, The output dimensionality for the projection
        matrices.  If None, no projection is performed.
      num_unit_shards: Stored for API compatibility; not used by this cell.
      num_proj_shards: How to split the projection matrix.  If >1, the
        projection matrix is stored across num_proj_shards.
      forget_bias: Added to the forget-gate pre-activation (default 1) in
        order to reduce the scale of forgetting at the beginning of
        the training.
      state_is_tuple: If True, accepted and returned states are 2-tuples of
        the `c_state` and `m_state`.  By default (False), they are concatenated
        along the column axis.  This default behavior will soon be deprecated.
      activation: Activation function of the inner states.
      momentum: Momentum of the batch-norm moving averages used at inference
        time. The default of 0.5 matches the previous hard-coded value. Larger
        values (e.g. 0.99) average over many more mini-batches.
    """
    if not state_is_tuple:
      logging.warning(
          "%s: Using a concatenated state is slower and will soon be "
          "deprecated.  Use state_is_tuple=True.", self)
    if input_size is not None:
      logging.warning("%s: The input_size parameter is deprecated.", self)
    self._num_units = num_units
    self.forward_only = forward_only
    self._gamma_c = gamma_c
    self._beta_c = beta_c
    self._gamma_h = gamma_h
    self._beta_h = beta_h
    self._gamma_x = gamma_x
    self._beta_x = beta_x
    self._use_peepholes = use_peepholes
    self._cell_clip = cell_clip
    self._initializer = initializer
    self._num_proj = num_proj
    self._num_unit_shards = num_unit_shards
    self._num_proj_shards = num_proj_shards
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    self._activation = activation
    self._momentum = momentum

    if num_proj:
      self._state_size = (
          tf.nn.rnn_cell.LSTMStateTuple(num_units, num_proj)
          if state_is_tuple else num_units + num_proj)
      self._output_size = num_proj
    else:
      self._state_size = (
          tf.nn.rnn_cell.LSTMStateTuple(num_units, num_units)
          if state_is_tuple else 2 * num_units)
      self._output_size = num_units

  @property
  def state_size(self):
    return self._state_size

  @property
  def output_size(self):
    return self._output_size

  def _batch_norm(self, x, beta, gamma, name):
    """Batch-normalize `x` with constant-initialized offset `beta` / scale `gamma`.

    Uses mini-batch statistics when `forward_only` is False and the stored
    moving averages otherwise. `name` fixes the variable names, so existing
    checkpoints keep loading.
    """
    return tf.layers.batch_normalization(
        x,
        momentum=self._momentum,
        beta_initializer=tf.constant_initializer(beta),
        gamma_initializer=tf.constant_initializer(gamma),
        training=(not self.forward_only),
        name=name, reuse=None)

  def __call__(self, inputs, state, scope=None):
    """Run one step of the batch-normalized LSTM.

    Args:
      inputs: input Tensor, 2D, batch x input_size.
      state: if `state_is_tuple` is False, this must be a state Tensor,
        `2-D, batch x state_size`.  If `state_is_tuple` is True, this must be a
        tuple of state Tensors, both `2-D`, with column sizes `c_state` and
        `m_state`.
      scope: VariableScope for the created subgraph; defaults to the class
        name ("BNLSTMCell").

    Returns:
      A tuple containing:
      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
        LSTM after reading `inputs` when previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - Tensor(s) representing the new state of LSTM after reading `inputs` when
        the previous state was `state`.  Same type and shape(s) as `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    num_proj = self._num_units if self._num_proj is None else self._num_proj

    if self._state_is_tuple:
      (c_prev, m_prev) = state
    else:
      c_prev = tf.slice(state, [0, 0], [-1, self._num_units])
      m_prev = tf.slice(state, [0, self._num_units], [-1, num_proj])

    dtype = inputs.dtype
    input_size = inputs.get_shape().with_rank(2)[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
    with tf.variable_scope(scope or type(self).__name__,
                           initializer=self._initializer):
      # Recurrent and input weights are kept separate (unlike the stock
      # LSTMCell) so that each product can be normalized on its own.
      w_h = tf.get_variable("W_h", [num_proj, 4 * self._num_units],
                            dtype=dtype)
      w_x = tf.get_variable("W_x", [input_size.value, 4 * self._num_units],
                            dtype=dtype)

      # Since TF 1.0 the initializer must be instantiated (called).
      b = tf.get_variable(
          "B", shape=[4 * self._num_units],
          initializer=tf.zeros_initializer(), dtype=dtype)

      bn_hidden_matrix = self._batch_norm(
          tf.matmul(m_prev, w_h), self._beta_h, self._gamma_h,
          'bn_hidden_matrix')
      bn_input_matrix = self._batch_norm(
          tf.matmul(inputs, w_x), self._beta_x, self._gamma_x,
          'bn_input_matrix')

      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
      lstm_matrix = tf.nn.bias_add(
          tf.add(bn_input_matrix, bn_hidden_matrix), b)
      i, j, f, o = tf.split(lstm_matrix, num_or_size_splits=4, axis=1)

      # Diagonal (peephole) connections
      if self._use_peepholes:
        w_f_diag = tf.get_variable(
            "W_F_diag", shape=[self._num_units], dtype=dtype)
        w_i_diag = tf.get_variable(
            "W_I_diag", shape=[self._num_units], dtype=dtype)
        w_o_diag = tf.get_variable(
            "W_O_diag", shape=[self._num_units], dtype=dtype)
        c = (tf.sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
             tf.sigmoid(i + w_i_diag * c_prev) * self._activation(j))
      else:
        c = (tf.sigmoid(f + self._forget_bias) * c_prev + tf.sigmoid(i) *
             self._activation(j))

      if self._cell_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        c = tf.clip_by_value(c, -self._cell_clip, self._cell_clip)
        # pylint: enable=invalid-unary-operand-type

      bn_c = self._batch_norm(c, self._beta_c, self._gamma_c, 'bn_cell')
      if self._use_peepholes:
        m = tf.sigmoid(o + w_o_diag * bn_c) * self._activation(bn_c)
      else:
        m = tf.sigmoid(o) * self._activation(bn_c)

      if self._num_proj is not None:
        # NOTE(review): `_get_concat_variable` is a private TF helper and may
        # not be exposed under `tf.nn.rnn_cell` in every TF version.
        concat_w_proj = tf.nn.rnn_cell._get_concat_variable(
            "W_P", [self._num_units, self._num_proj],
            dtype, self._num_proj_shards)

        m = tf.matmul(m, concat_w_proj)

    # The un-normalized cell state `c` is carried to the next step; only the
    # output path above uses the normalized `bn_c`.
    # TF >= 1.0 signature is tf.concat(values, axis).
    new_state = (tf.nn.rnn_cell.LSTMStateTuple(c, m) if self._state_is_tuple
                 else tf.concat([c, m], 1))
    return m, new_state




我构建了一个序列到序列(seq2seq)模型,并按照其他帖子中的建议,在训练期间额外运行批量标准化的更新操作(tf.GraphKeys.UPDATE_OPS):



 extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
 if extra_update_ops and not forward_only:
     # Batch-norm moving-average updates are only registered in UPDATE_OPS and
     # are not run implicitly, so fetch them together with the training
     # outputs. Their return values are not needed.
     # NOTE(review): get_collection without a scope returns the update ops of
     # every model built in this graph -- confirm that is intended.
     outputs, _ = session.run([output_feed, extra_update_ops], input_feed)
 else:
     outputs = session.run(output_feed, input_feed)




训练损失看起来很合理。

但是,测试时的输出完全是无意义的结果。我想知道是否有人遇到过类似的情况,并知道如何解决。

0 个答案:

没有答案