我尝试使用最新 TensorFlow API 中提供的 `tf.layers.batch_normalization`
函数来实现循环批量归一化(Recurrent Batch Normalization)的 LSTM。
实现如下(我在 TF 源代码中 LSTMCell 的基础上做了修改):
class BNLSTMCell(tf.nn.rnn_cell.RNNCell):
    """Batch-normalized long short-term memory (LSTM) recurrent network cell.

    Batch normalization is applied separately to the input projection
    (``inputs @ W_x``), to the recurrent projection (``m_prev @ W_h``) and to
    the cell state before the output activation. The cell state carried to
    the next step is the un-normalized one.

    cf. Recurrent Batch Normalization
    https://arxiv.org/abs/1603.09025
    cf. A Gentle Guide to Using Batch Normalization in TensorFlow
    http://ruishu.io/2016/12/27/batchnorm/

    NOTE(review): the batch-normalization layers use fixed names and
    ``reuse=None``, so they inherit the reuse flag of the enclosing variable
    scope. If the cell is unrolled over time inside a reusing scope, all time
    steps share ONE set of moving mean/variance variables, and every step adds
    its own update ops (writing to those same variables) to
    ``tf.GraphKeys.UPDATE_OPS``. The referenced paper keeps separate population
    statistics per time step, so this sharing is a likely source of a
    train/test mismatch -- confirm against how the cell is unrolled.
    """

    def __init__(self, num_units, forward_only, gamma_c=1.0, gamma_h=1.0,
                 gamma_x=1.0, beta_c=0.0, beta_h=0.0, beta_x=0.0,
                 input_size=None, use_peepholes=False, cell_clip=None,
                 initializer=None, num_proj=None,
                 num_unit_shards=1, num_proj_shards=1,
                 forget_bias=1.0, state_is_tuple=False,
                 activation=tf.tanh, bn_momentum=0.5):
        """Initialize the parameters for a batch-normalized LSTM cell.

        Args:
          num_units: int, the number of units in the LSTM cell.
          forward_only: Selects the batch-normalization mode when the graph is
            built (presumably a Python bool -- it is negated with ``not``).
            If False (training):
              1. Normalize layer activations with mini-batch statistics.
              2. Update the population-statistics approximation via a moving
                 average of mini-batch statistics (the update ops go to
                 ``tf.GraphKeys.UPDATE_OPS`` and must be run by the caller).
            If True (testing):
              1. Normalize layer activations with the estimated population
                 statistics.
              2. Do not update population statistics from test mini-batches.
          gamma_c: Initial scale of the cell-state normalization.
          beta_c: Initial offset of the cell-state normalization.
          gamma_h: Initial scale of the hidden-state normalization.
          beta_h: Initial offset of the hidden-state normalization
            (set to 0 to avoid redundancy with the gate bias).
          gamma_x: Initial scale of the input normalization.
          beta_x: Initial offset of the input normalization
            (set to 0 to avoid redundancy with the gate bias).
          input_size: Deprecated and unused; only triggers a warning.
          use_peepholes: bool, set True to enable diagonal/peephole connections.
          cell_clip: (optional) float; if provided, the cell state is clipped
            to [-cell_clip, cell_clip] before normalization and activation.
          initializer: (optional) initializer for the weight and projection
            matrices (set as the default of the cell's variable scope).
          num_proj: (optional) int, output dimensionality of the projection
            matrix. If None, no projection is performed.
          num_unit_shards: Accepted for signature compatibility with LSTMCell;
            not used by this implementation (W_h / W_x are not sharded).
          num_proj_shards: Number of shards the projection matrix is stored in.
          forget_bias: Added to the forget-gate pre-activation to reduce
            forgetting at the beginning of training.
          state_is_tuple: If True, accepted and returned states are
            ``LSTMStateTuple(c, m)``. If False (default), ``c`` and ``m`` are
            concatenated along the column axis (slower, being deprecated).
          activation: Activation function of the inner states.
          bn_momentum: Momentum of the moving averages kept by the
            batch-normalization layers for inference. Defaults to 0.5 (the
            value previously hard-coded); such a low value makes the
            population statistics track the last few training batches closely
            (``tf.layers.batch_normalization`` itself defaults to 0.99).
        """
        super(BNLSTMCell, self).__init__()
        if not state_is_tuple:
            logging.warning(
                "%s: Using a concatenated state is slower and will soon be "
                "deprecated. Use state_is_tuple=True.", self)
        if input_size is not None:
            logging.warning("%s: The input_size parameter is deprecated.", self)
        self._num_units = num_units
        self.forward_only = forward_only
        self._gamma_c = gamma_c
        self._beta_c = beta_c
        self._gamma_h = gamma_h
        self._beta_h = beta_h
        self._gamma_x = gamma_x
        self._beta_x = beta_x
        self._use_peepholes = use_peepholes
        self._cell_clip = cell_clip
        self._initializer = initializer
        self._num_proj = num_proj
        self._num_unit_shards = num_unit_shards
        self._num_proj_shards = num_proj_shards
        self._forget_bias = forget_bias
        self._state_is_tuple = state_is_tuple
        self._activation = activation
        self._bn_momentum = bn_momentum

        # Width of the hidden output m: the projection size if any, else num_units.
        m_size = num_proj if num_proj else num_units
        self._state_size = (
            tf.nn.rnn_cell.LSTMStateTuple(num_units, m_size)
            if state_is_tuple else num_units + m_size)
        self._output_size = m_size

    @property
    def state_size(self):
        """Size(s) of the state: an LSTMStateTuple or the concatenated width."""
        return self._state_size

    @property
    def output_size(self):
        """Width of the cell output (num_proj if set, else num_units)."""
        return self._output_size

    def _batch_norm(self, x, beta, gamma, name):
        """Batch-normalize ``x`` with a ``tf.layers.batch_normalization`` layer.

        ``training`` is ``not self.forward_only``. In training mode, mini-batch
        statistics are used and moving-average update ops are registered in
        ``tf.GraphKeys.UPDATE_OPS``; otherwise the stored moving averages are used.
        """
        return tf.layers.batch_normalization(
            x,
            momentum=self._bn_momentum,
            beta_initializer=tf.constant_initializer(beta),
            gamma_initializer=tf.constant_initializer(gamma),
            training=not self.forward_only,
            name=name,
            # None: inherit the reuse flag of the enclosing variable scope.
            reuse=None)

    def __call__(self, inputs, state, scope=None):
        """Run one step of the batch-normalized LSTM.

        Args:
          inputs: input Tensor, 2-D, ``batch x input_size`` (input_size must be
            statically known).
          state: if ``state_is_tuple`` is False, a 2-D state Tensor of shape
            ``batch x (num_units + output_dim)``. If True, a tuple ``(c, m)``
            of 2-D Tensors with column sizes ``num_units`` and ``output_dim``.
          scope: VariableScope for the created subgraph; defaults to the class
            name ("BNLSTMCell").

        Returns:
          A tuple containing:
          - A 2-D ``[batch x output_dim]`` Tensor, the output of the LSTM after
            reading ``inputs`` when the previous state was ``state``, where
            output_dim is num_proj if set, else num_units.
          - The new state, with the same structure as ``state``. The cell part
            is the un-normalized cell state.

        Raises:
          ValueError: If the input size cannot be inferred from inputs via
            static shape inference.
        """
        num_proj = self._num_units if self._num_proj is None else self._num_proj
        if self._state_is_tuple:
            c_prev, m_prev = state
        else:
            c_prev = tf.slice(state, [0, 0], [-1, self._num_units])
            m_prev = tf.slice(state, [0, self._num_units], [-1, num_proj])

        dtype = inputs.dtype
        input_size = inputs.get_shape().with_rank(2)[1]
        if input_size.value is None:
            raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

        with tf.variable_scope(scope or type(self).__name__,
                               initializer=self._initializer):
            # Weights follow the input dtype, like the bias and peephole
            # variables (previously hard-coded to float32).
            w_h = tf.get_variable("W_h", [num_proj, 4 * self._num_units],
                                  dtype=dtype)
            w_x = tf.get_variable("W_x", [input_size.value, 4 * self._num_units],
                                  dtype=dtype)
            b = tf.get_variable(
                "B", shape=[4 * self._num_units],
                initializer=tf.zeros_initializer(), dtype=dtype)

            # Normalize the recurrent and input projections separately, then
            # add the (unnormalized) gate bias.
            bn_hidden_matrix = self._batch_norm(
                tf.matmul(m_prev, w_h), self._beta_h, self._gamma_h,
                "bn_hidden_matrix")
            bn_input_matrix = self._batch_norm(
                tf.matmul(inputs, w_x), self._beta_x, self._gamma_x,
                "bn_input_matrix")
            lstm_matrix = tf.nn.bias_add(
                tf.add(bn_input_matrix, bn_hidden_matrix), b)

            # i = input_gate, j = new_input, f = forget_gate, o = output_gate
            i, j, f, o = tf.split(lstm_matrix, num_or_size_splits=4, axis=1)

            if self._use_peepholes:
                # Diagonal (peephole) connections.
                w_f_diag = tf.get_variable(
                    "W_F_diag", shape=[self._num_units], dtype=dtype)
                w_i_diag = tf.get_variable(
                    "W_I_diag", shape=[self._num_units], dtype=dtype)
                w_o_diag = tf.get_variable(
                    "W_O_diag", shape=[self._num_units], dtype=dtype)
                c = (tf.sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev
                     + tf.sigmoid(i + w_i_diag * c_prev) * self._activation(j))
            else:
                c = (tf.sigmoid(f + self._forget_bias) * c_prev
                     + tf.sigmoid(i) * self._activation(j))

            if self._cell_clip is not None:
                # pylint: disable=invalid-unary-operand-type
                c = tf.clip_by_value(c, -self._cell_clip, self._cell_clip)
                # pylint: enable=invalid-unary-operand-type

            # Only the output path sees the normalized cell state; the
            # un-normalized c is what is carried to the next step.
            bn_c = self._batch_norm(c, self._beta_c, self._gamma_c, "bn_cell")

            if self._use_peepholes:
                # NOTE(review): the output peephole reads the normalized cell
                # state, while the input/forget peepholes read c_prev -- confirm
                # this asymmetry is intended.
                m = tf.sigmoid(o + w_o_diag * bn_c) * self._activation(bn_c)
            else:
                m = tf.sigmoid(o) * self._activation(bn_c)

            if self._num_proj is not None:
                # NOTE(review): _get_concat_variable is a private TensorFlow
                # helper and may not exist in every TF version.
                concat_w_proj = tf.nn.rnn_cell._get_concat_variable(
                    "W_P", [self._num_units, self._num_proj],
                    dtype, self._num_proj_shards)
                m = tf.matmul(m, concat_w_proj)

        # TF >= 1.0 signature is tf.concat(values, axis), matching the
        # tf.split(..., axis=1) call above.
        new_state = (tf.nn.rnn_cell.LSTMStateTuple(c, m) if self._state_is_tuple
                     else tf.concat([c, m], 1))
        return m, new_state

我构建了一个序列到序列(seq2seq)模型,并按照其他帖子中的建议,在训练期间额外运行 `tf.GraphKeys.UPDATE_OPS` 中的更新操作:
# Ops registered in UPDATE_OPS (e.g. batch-norm moving-average updates) are
# not run automatically, so fetch them explicitly together with the outputs.
extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
if forward_only or not extra_update_ops:
    # Inference, or nothing registered to update: fetch only the outputs.
    outputs = session.run(output_feed, input_feed)
else:
    # Training: run the update ops in the same step as the requested outputs.
    fetches = [output_feed, extra_update_ops]
    outputs, extra_updates = session.run(fetches, input_feed)

训练损失看起来很合理,
但测试时的输出却毫无意义。请问有没有人遇到过类似的情况,并知道如何解决?