如何使用TensorArray计算RNN隐藏状态的渐变

时间:2018-04-30 15:04:09

标签: tensorflow rnn

我想计算相对于二维LSTM中所有隐藏状态的损失函数的梯度,这是使用TensorArraytf.while_loop实现的(参考this repo实施细节)。基本思想是我们有一个TensorArray(长度为time_steps)来存储所有隐藏状态,我们使用tf.while_loop来逐步计算隐藏状态。在while循环体中,我们根据h_t计算h_t-1,如:

h_t-1 = h_tensor_array.read(t-1) # read previous hidden states from TensorArray
x_t = x_tensor_array.read(t) # read current input from TensorArray
o_t, h_t = cell(x_t, h_t-1) # RNN cell
h_tensor_array.write(t, h_t) # update hidden states TensorArray
o_tensor_array.write(t, o_t) # update output TensorArray

如果我只是使用

all_hidden_states = h_tensor_array.stack() # get all hidden states Tensor
tf.gradients(loss, all_hidden_states) # None

(损失根据最终隐藏状态计算),梯度为None。我怀疑每次调用stack的{​​{1}}或read方法时,返回的TensorArray 不是完全相同的实例,意味着他们(Tensorh_t-1)在计算图中是不同的节点。由于我们在all_hidden_states中使用h_t-1,因此cellloss之间存在路径,但h_t-1loss之间不存在路径。有没有人知道如何获得所有隐藏状态的渐变?谢谢!

0 个答案:

没有答案