How can I pass the LSTMCell's variables as inputs instead of letting it create them in TensorFlow?

Asked: 2018-08-12 00:01:34

Tags: tensorflow lstm

When I create a tf.contrib.rnn.LSTMCell, it creates its kernel and bias trainable variables itself during initialization.

How the code looks now:

cell_fw = tf.contrib.rnn.LSTMCell(hidden_size_char,
                        state_is_tuple=True)

What I would like it to look like:

kernel = tf.get_variable(...)
bias = tf.get_variable(...)
cell_fw = tf.contrib.rnn.LSTMCell(kernel, bias, hidden_size,
                        state_is_tuple=True)

What I want to do is create these variables myself and hand them to the LSTMCell class when instantiating it, as inputs to its __init__.

Is there an easy way to do this? I looked at the class source code, but it seems buried in a complex class hierarchy.

1 Answer:

Answer 0 (score: 1)

I subclassed the LSTMCell class and changed its __init__ and build methods so that they accept externally created variables. If the variables are passed in via __init__, then build no longer calls get_variable and instead uses the given kernel and bias variables.

There may well be a cleaner way, though.

import tensorflow as tf
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import partitioned_variables

_BIAS_VARIABLE_NAME = "bias"
_WEIGHTS_VARIABLE_NAME = "kernel"

class MyLSTMCell(tf.contrib.rnn.LSTMCell):
    def __init__(self, num_units,
                 use_peepholes=False, cell_clip=None,
                 initializer=None, num_proj=None, proj_clip=None,
                 num_unit_shards=None, num_proj_shards=None,
                 forget_bias=1.0, state_is_tuple=True,
                 activation=None, reuse=None, name=None, var_given=False, kernel=None, bias=None):

        super(MyLSTMCell, self).__init__(num_units,
                 use_peepholes=use_peepholes, cell_clip=cell_clip,
                 initializer=initializer, num_proj=num_proj, proj_clip=proj_clip,
                 num_unit_shards=num_unit_shards, num_proj_shards=num_proj_shards,
                 forget_bias=forget_bias, state_is_tuple=state_is_tuple,
                 activation=activation, reuse=reuse, name=name)

        self.var_given = var_given
        if self.var_given:
            self._kernel = kernel
            self._bias = bias


    def build(self, inputs_shape):
        if inputs_shape[1].value is None:
            raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
                             % inputs_shape)

        input_depth = inputs_shape[1].value
        h_depth = self._num_units if self._num_proj is None else self._num_proj
        maybe_partitioner = (
            partitioned_variables.fixed_size_partitioner(self._num_unit_shards)
            if self._num_unit_shards is not None
            else None)
        if self.var_given:
            # self._kernel and self._bias were already set in __init__
            pass
        else:
            self._kernel = self.add_variable(
                _WEIGHTS_VARIABLE_NAME,
                shape=[input_depth + h_depth, 4 * self._num_units],
                initializer=self._initializer,
                partitioner=maybe_partitioner)
            self._bias = self.add_variable(
                _BIAS_VARIABLE_NAME,
                shape=[4 * self._num_units],
                initializer=init_ops.zeros_initializer(dtype=self.dtype))
        if self._use_peepholes:
            self._w_f_diag = self.add_variable("w_f_diag", shape=[self._num_units],
                                               initializer=self._initializer)
            self._w_i_diag = self.add_variable("w_i_diag", shape=[self._num_units],
                                               initializer=self._initializer)
            self._w_o_diag = self.add_variable("w_o_diag", shape=[self._num_units],
                                               initializer=self._initializer)

        if self._num_proj is not None:
            maybe_proj_partitioner = (
                partitioned_variables.fixed_size_partitioner(self._num_proj_shards)
                if self._num_proj_shards is not None
                else None)
            self._proj_kernel = self.add_variable(
                "projection/%s" % _WEIGHTS_VARIABLE_NAME,
                shape=[self._num_units, self._num_proj],
                initializer=self._initializer,
                partitioner=maybe_proj_partitioner)

        self.built = True

So the calling code will look like this:

kernel = tf.get_variable(...)
bias = tf.get_variable(...)
lstm_fw = MyLSTMCell(....., var_given=True, kernel=kernel, bias=bias)
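For reference, here is a minimal sketch of creating those variables with the shapes that build expects (assuming no projection, so the recurrent depth equals num_units; the sizes and scope name below are only illustrative):

num_units = 128     # hidden size of the cell (example value)
input_depth = 50    # feature size of the input at each time step (example value)

# Shapes must match what build would otherwise create itself:
# kernel: [input_depth + num_units, 4 * num_units], bias: [4 * num_units]
with tf.variable_scope("shared_lstm_vars"):
    kernel = tf.get_variable(
        "kernel",
        shape=[input_depth + num_units, 4 * num_units],
        initializer=tf.glorot_uniform_initializer())
    bias = tf.get_variable(
        "bias",
        shape=[4 * num_units],
        initializer=tf.zeros_initializer())

cell_fw = MyLSTMCell(num_units, state_is_tuple=True,
                     var_given=True, kernel=kernel, bias=bias)

Built this way, the same kernel and bias can be shared between several cells or initialized from values computed elsewhere.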