当我创建一个 tf.contrib.rnn.LSTMCell 时,它会自行创建其 kernel 和 bias 可训练变量(实际上是在首次调用 build 时创建,而不是在 __init__ 中)。
目前的代码如下:
# Forward LSTM cell with `hidden_size_char` units and tuple-structured state.
# NOTE(review): LSTMCell creates its kernel/bias lazily in build() (on the
# first call), not in this constructor — confirm for your TF version.
cell_fw = tf.contrib.rnn.LSTMCell(num_units=hidden_size_char,
                                  state_is_tuple=True)
我希望它变成这样:
# Sketch of the DESIRED interface only: tf.contrib.rnn.LSTMCell does not take
# kernel/bias arguments like this, so this snippet does not run as written.
# The `...` placeholders stand for the caller's own get_variable arguments.
kernel = tf.get_variable(...)
bias = tf.get_variable(...)
cell_fw = tf.contrib.rnn.LSTMCell(kernel, bias, hidden_size,
state_is_tuple=True)
我想自己创建这些变量,并在实例化 LSTMCell 时把它们作为 `__init__` 的参数传进去。
有没有简单的方法可以做到这一点?我看了该类的源代码,但它似乎处在一个复杂的类层次结构中。
答案 0 :(得分:1)
我继承了 LSTMCell 类,并修改了它的 `__init__` 和 `build` 方法,使其能够接受给定的变量。如果在 `__init__` 中传入了变量,那么在 `build` 中就不再通过 `add_variable` 创建新变量,而是直接使用传入的 kernel 和 bias 变量。
不过,可能还有更简洁的做法。
# Variable names used when the cell creates its own kernel/bias in build().
_BIAS_VARIABLE_NAME = "bias"
_WEIGHTS_VARIABLE_NAME = "kernel"


class MyLSTMCell(tf.contrib.rnn.LSTMCell):
    """LSTMCell that can use a caller-supplied kernel and bias.

    With ``var_given=False`` (the default) this behaves like
    ``tf.contrib.rnn.LSTMCell`` and creates its kernel and bias in ``build``.
    With ``var_given=True``, ``build`` skips creating those two variables and
    uses the ``kernel`` and ``bias`` objects passed to the constructor.
    Peephole and projection variables, if enabled, are still created by
    ``build`` in both modes.

    NOTE(review): supplied variables are not created through
    ``add_variable``, so depending on the TF version they may not be listed
    in this cell's ``trainable_weights``/``variables`` — confirm before
    relying on those properties.
    """

    def __init__(self, num_units,
                 use_peepholes=False, cell_clip=None,
                 initializer=None, num_proj=None, proj_clip=None,
                 num_unit_shards=None, num_proj_shards=None,
                 forget_bias=1.0, state_is_tuple=True,
                 activation=None, reuse=None, name=None,
                 var_given=False, kernel=None, bias=None):
        """Create the cell.

        Args:
            num_units ... name: forwarded unchanged to ``LSTMCell.__init__``.
            var_given: if true, ``build`` uses ``kernel`` and ``bias`` instead
                of creating new variables.
            kernel: weight matrix used when ``var_given`` is true. ``build``
                would otherwise create it with shape
                ``[input_depth + h_depth, 4 * num_units]``, where ``h_depth``
                is ``num_proj`` if set, else ``num_units``.
            bias: bias vector used when ``var_given`` is true. ``build`` would
                otherwise create it with shape ``[4 * num_units]``.

        Raises:
            ValueError: if ``var_given`` is true but ``kernel`` or ``bias``
                is None.
        """
        # Fail fast: otherwise a None kernel/bias is stored here and only
        # blows up later, with an unrelated-looking error, when the cell runs.
        if var_given and (kernel is None or bias is None):
            raise ValueError(
                "var_given=True requires both `kernel` and `bias` to be given.")
        super(MyLSTMCell, self).__init__(
            num_units,
            use_peepholes=use_peepholes, cell_clip=cell_clip,
            initializer=initializer, num_proj=num_proj, proj_clip=proj_clip,
            num_unit_shards=num_unit_shards, num_proj_shards=num_proj_shards,
            forget_bias=forget_bias, state_is_tuple=state_is_tuple,
            activation=activation, reuse=reuse, name=name)
        self.var_given = var_given
        if self.var_given:
            self._kernel = kernel
            self._bias = bias

    def build(self, inputs_shape):
        """Create the cell's variables, reusing a supplied kernel/bias.

        Args:
            inputs_shape: static shape of the cell inputs; dimension 1 (the
                input depth) must be known.

        Raises:
            ValueError: if the input depth is not statically known.
        """
        if inputs_shape[1].value is None:
            raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
                             % inputs_shape)
        input_depth = inputs_shape[1].value
        h_depth = self._num_units if self._num_proj is None else self._num_proj

        if not self.var_given:
            # Public TF1 APIs are used here; the TF-internal modules
            # `partitioned_variables` / `init_ops` are not in scope in user code.
            maybe_partitioner = (
                tf.fixed_size_partitioner(self._num_unit_shards)
                if self._num_unit_shards is not None
                else None)
            self._kernel = self.add_variable(
                _WEIGHTS_VARIABLE_NAME,
                shape=[input_depth + h_depth, 4 * self._num_units],
                initializer=self._initializer,
                partitioner=maybe_partitioner)
            self._bias = self.add_variable(
                _BIAS_VARIABLE_NAME,
                shape=[4 * self._num_units],
                initializer=tf.zeros_initializer(dtype=self.dtype))
        # else: self._kernel and self._bias were already set in __init__.

        if self._use_peepholes:
            self._w_f_diag = self.add_variable("w_f_diag", shape=[self._num_units],
                                               initializer=self._initializer)
            self._w_i_diag = self.add_variable("w_i_diag", shape=[self._num_units],
                                               initializer=self._initializer)
            self._w_o_diag = self.add_variable("w_o_diag", shape=[self._num_units],
                                               initializer=self._initializer)

        if self._num_proj is not None:
            maybe_proj_partitioner = (
                tf.fixed_size_partitioner(self._num_proj_shards)
                if self._num_proj_shards is not None
                else None)
            self._proj_kernel = self.add_variable(
                "projection/%s" % _WEIGHTS_VARIABLE_NAME,
                shape=[self._num_units, self._num_proj],
                initializer=self._initializer,
                partitioner=maybe_proj_partitioner)

        self.built = True
使用方式如下:
# Create the kernel and bias yourself, with the same shapes MyLSTMCell.build
# would otherwise create (no projection, so h_depth == hidden_size):
#   kernel: [input_depth + hidden_size, 4 * hidden_size]
#   bias:   [4 * hidden_size], zero-initialized like the built-in default.
# `input_depth` (last dimension of the cell inputs) and `hidden_size` must be
# defined by the caller.
kernel = tf.get_variable("kernel",
                         shape=[input_depth + hidden_size, 4 * hidden_size])
bias = tf.get_variable("bias", shape=[4 * hidden_size],
                       initializer=tf.zeros_initializer())
lstm_fw = MyLSTMCell(hidden_size, var_given=True, kernel=kernel, bias=bias)