如何在tensorflow中修改BasicLSTMcell中的初始值

时间:2018-01-24 08:21:12

标签: python tensorflow lstm recurrent-neural-network

我想用Tensorflow中的BasicLSTMcell初始化重量和偏差值,并使用我预先训练的值(我通过.npy得到它们)。但是当我使用get_tensor_by_name来获取张量时,它似乎只是为我返回一个副本,并且原始值永远不会改变。我需要你的帮助!

2 个答案:

答案 0 :(得分:0)

在第一次迭代中添加一些tf.assign操作,将您想要的值分配给内部变量。

确保这只发生在第一次迭代中,否则你将覆盖你所做的任何训练。

如果您愿意,单元格有一个名为get_trainable_variables的方法可以帮助您。

答案 1 :(得分:0)

我认为BasicRNNCell 太基本。创建自己的单元格会更容易,BasicRNNCell来自kernel_initializerbias_initializer

import tensorflow as tf
from tensorflow.python.ops.rnn_cell_impl import _Linear

class MyNNCell(tf.nn.rnn_cell.BasicRNNCell):
  def __init__(self, num_units, activation=None, reuse=None,
               kernel_initializer=None,
               bias_initializer=None):
    super(MyNNCell, self).__init__(num_units, activation, reuse)
    self._kernel_initializer = kernel_initializer
    self._bias_initializer = bias_initializer

  def call(self, inputs, state):
    if self._linear is None:
      self._linear = _Linear([inputs, state], self._num_units, True,
                             bias_initializer=self._bias_initializer,
                             kernel_initializer=self._kernel_initializer)
    return super(MyNNCell, self).call(inputs, state)

这可能也有帮助:tf.nn.static_rnntf.nn.dynamic_rnn都有initial_state参数。它没有设置内核和单元的偏差,而是设置它在第一次调用时要接收的状态。由于BasicLSTMcell中的内核和偏差始终用零初始化,因此它是预设单元状态的等效方法。