我想计算相对于二维LSTM中所有隐藏状态的损失函数的梯度,这是使用TensorArray
和tf.while_loop
实现的(参考this repo实施细节)。基本思想是我们有一个TensorArray
(长度为time_steps
)来存储所有隐藏状态,我们使用tf.while_loop
来逐步计算隐藏状态。在while循环体中,我们根据h_t
计算h_t-1
,如:
h_t-1 = h_tensor_array.read(t-1) # read previous hidden states from TensorArray
x_t = x_tensor_array.read(t) # read current input from TensorArray
o_t, h_t = cell(x_t, h_t-1) # RNN cell
h_tensor_array.write(t, h_t) # update hidden states TensorArray
o_tensor_array.write(t, o_t) # update output TensorArray
如果我只是使用
all_hidden_states = h_tensor_array.stack() # get all hidden states Tensor
tf.gradients(loss, all_hidden_states) # None
(损失根据最终隐藏状态计算),梯度为None
。我怀疑每次调用stack
的{{1}}或read
方法时,返回的TensorArray
不是完全相同的实例,意味着他们(Tensor
和h_t-1
)在计算图中是不同的节点。由于我们在all_hidden_states
中使用h_t-1
,因此cell
和loss
之间存在路径,但h_t-1
和loss
之间不存在路径。有没有人知道如何获得所有隐藏状态的渐变?谢谢!