TensorFlow - LSTM with an attention mechanism - 'Tensor' object is not iterable

Time: 2018-09-26 14:58:16

Tags: tensorflow lstm

I defined a modified LSTM cell based on TensorFlow's LSTMCell class and tried to add an attention mechanism on top of it, as follows:

import tensorflow as tf

class MyLSTMCell(tf.nn.rnn_cell.LSTMCell):
    def __init__(self, num_units, input_size, grid_size,
                 use_peepholes=True, cell_clip=None,
                 initializer=tf.random_normal_initializer(stddev=0.01), num_proj=None, proj_clip=None,
                 num_unit_shards=None, num_proj_shards=None,
                 forget_bias=1.0, batch_size=32, state_is_tuple=False,
                 activation=None, reuse=None, name=None, with_att=True):
        # Call the LayerRNNCell initializer directly, skipping LSTMCell.__init__
        # so that it does not overwrite the attributes set below.
        super(tf.nn.rnn_cell.LSTMCell, self).__init__(_reuse=reuse, name=name)

        # self.input_spec = tf.nn.rnn_cell.base_layer.InputSpec(ndim=2)
        self._batch_size = batch_size
        self._num_units = num_units
        self._input_size = input_size
        self._grid_size = grid_size
        self._use_peepholes = use_peepholes
        self._cell_clip = cell_clip
        self._initializer = initializer
        self._num_proj = num_proj
        self._proj_clip = proj_clip
        self._num_unit_shards = num_unit_shards
        self._num_proj_shards = num_proj_shards
        self._forget_bias = forget_bias
        self._state_is_tuple = state_is_tuple
        self._activation = activation or tf.tanh
        self._with_att = with_att

        if num_proj:
            self._state_size = (
                tf.nn.rnn_cell.LSTMStateTuple(num_units, num_proj)
                if state_is_tuple else num_units + num_proj)
            self._output_size = num_proj
        else:
            self._state_size = (
                tf.nn.rnn_cell.LSTMStateTuple(num_units, num_units)
                if state_is_tuple else 2 * num_units)
            self._output_size = num_units

    @property
    def state_size(self):
        return self._state_size

    @property
    def output_size(self):
        return self._output_size

    def build(self, inputs_shape):
        # Per-step inputs are 3-D (batch, grid, depth), so the feature depth
        # sits at index 2 rather than the usual index 1.
        if inputs_shape[2].value is None:
            raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
                             % inputs_shape)

        input_depth = inputs_shape[2].value
        h_depth = self._num_units if self._num_proj is None else self._num_proj
        maybe_partitioner = (
            tf.fixed_size_partitioner(self._num_unit_shards)
            if self._num_unit_shards is not None
            else None)
        self._kernel = self.add_variable(
            "kernel",
            shape=[input_depth + h_depth, 4 * self._num_units],
            initializer=self._initializer,
            partitioner=maybe_partitioner)
        self._bias = self.add_variable(
            "bias",
            shape=[4 * self._num_units],
            initializer=tf.zeros_initializer(dtype=self.dtype))
        if self._use_peepholes:
            self._w_f_diag = self.add_variable("w_f_diag", shape=[self._num_units],
                                               initializer=self._initializer)
            self._w_i_diag = self.add_variable("w_i_diag", shape=[self._num_units],
                                               initializer=self._initializer)
            self._w_o_diag = self.add_variable("w_o_diag", shape=[self._num_units],
                                               initializer=self._initializer)

        if self._num_proj is not None:
            maybe_proj_partitioner = (
                tf.fixed_size_partitioner(self._num_proj_shards)
                if self._num_proj_shards is not None
                else None)
            self._proj_kernel = self.add_variable(
                "projection/kernel",
                shape=[self._num_units, self._num_proj],
                initializer=self._initializer,
                partitioner=maybe_proj_partitioner)

        self.built = True

    def call(self, inputs, state):
        print(state.shape)
        num_proj = self._num_units if self._num_proj is None else self._num_proj
        sigmoid = tf.sigmoid
        if self._state_is_tuple:
            (c_prev, h_prev) = state
        else:
            # With state_is_tuple=False the state is a single tensor with c and h
            # concatenated along axis 1.
            c_prev = tf.slice(state, [0, 0], [-1, self._num_units])
            h_prev = tf.slice(state, [0, self._num_units], [-1, num_proj])

        print(inputs.shape)
        input_size = inputs.get_shape().with_rank(3)[2]
        if input_size.value is None:
            raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
        Kx = tf.layers.dense(inputs, self._input_size, use_bias=False)
        print('begin')
        print(Kx.shape)
        Uh = tf.layers.dense(h_prev, self._input_size, use_bias=True)
        Uh = tf.expand_dims(Uh, axis=1)  # (batch, 1, depth), broadcasts over the grid axis
        print(Uh.shape)
        e = tf.tanh(Uh + Kx)
        print(e.shape)
        e = tf.layers.dense(e, self._grid_size)
        print(e.shape)
        alpha = tf.nn.softmax(e, axis=2)
        print(alpha.shape)
        z = tf.reduce_mean(tf.matmul(alpha, inputs), axis=1)
        print(z.shape)
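        # Expected shapes, assuming the grid axis of inputs equals grid_size
        # (B = batch, G = grid_size, D = input_size):
        #   Kx: (B, G, D), Uh: (B, 1, D), e: (B, G, G),
        #   alpha: (B, G, G), matmul(alpha, inputs): (B, G, D), z: (B, D)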
        # i = input_gate, g = new_input, f = forget_gate, o = output_gate
        lstm_matrix = tf.matmul(
            tf.concat([z, h_prev], 1), self._kernel)
        lstm_matrix = tf.nn.bias_add(lstm_matrix, self._bias)
        i, g, f, o = tf.split(
            value=lstm_matrix, num_or_size_splits=4, axis=1)
        # Diagonal connections
        if self._use_peepholes:
            c = (sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev +
                 sigmoid(i + self._w_i_diag * c_prev) * self._activation(g))
        else:
            c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
                 self._activation(g))

        if self._cell_clip is not None:
            # pylint: disable=invalid-unary-operand-type
            c = tf.clip_by_value(c, -self._cell_clip, self._cell_clip)
            # pylint: enable=invalid-unary-operand-type
        if self._use_peepholes:
            h = sigmoid(o + self._w_o_diag * c) * self._activation(c)
        else:
            h = sigmoid(o) * self._activation(c)

        if self._num_proj is not None:
            h = tf.matmul(h, self._proj_kernel)

            if self._proj_clip is not None:
                # pylint: disable=invalid-unary-operand-type
                h = tf.clip_by_value(h, -self._proj_clip, self._proj_clip)
                # pylint: enable=invalid-unary-operand-type
        new_state = (tf.nn.rnn_cell.LSTMStateTuple(c, h) if self._state_is_tuple else
                     tf.concat([c, h], 1))
        return h, new_state

I want to initialize the first state as:

# feature is defined elsewhere; its mean over axis 1 seeds the initial state
c0 = tf.layers.dense(tf.reduce_mean(feature, axis=1), num_units, activation=tf.tanh)
h0 = tf.layers.dense(tf.reduce_mean(feature, axis=1), num_units, activation=tf.tanh)
init_state = tf.nn.rnn_cell.LSTMStateTuple(c0, h0)

But I get the error: "Tensor objects are not iterable when eager execution is not enabled". I have paid close attention to state_size and still cannot see why this happens. Can anyone help me? Thanks!
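For context, this is roughly how the cell is meant to be driven; a minimal sketch, assuming a 4-D input tensor laid out as batch x time x grid x depth (num_steps and the placeholder are illustrative, not my exact code):

num_steps = 10  # illustrative
inputs = tf.placeholder(tf.float32, [batch_size, num_steps, grid_size, input_size])
# An LSTMStateTuple initial state only matches state_size when the cell is
# built with state_is_tuple=True; the default above is False, in which case
# the cell expects a single concatenated [c, h] tensor instead.
cell = MyLSTMCell(num_units, input_size, grid_size, state_is_tuple=True)
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=init_state)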

0 Answers:

There are no answers yet.