TensorFlow:在每个时间步打印RNN的内部状态

时间:2018-05-21 11:22:00

标签: python debugging tensorflow lstm

我正在使用tf.nn.dynamic_rnn类来创建LSTM。我已经在一些数据上训练了这个模型,现在我想在每次提供输入时检查这个训练过的LSTM隐藏状态的值是什么。

在SO和TensorFlow的GitHub页面上进行了一些挖掘之后,我看到有些人提到我应该编写自己的LSTM单元格,它返回我想要打印的任何内容,作为output的一部分。 LSTM。然而,这对我来说似乎并不直接,因为隐藏状态和LSTM的输出形状不同。

来自LSTM的输出张量具有形状[16, 1],隐藏状态是形状[16, 16]的张量。连接它们会导致形状张量[16, 17]。当我试图返回它时,我得到一个错误,说某些TensorFlow操作需要一个形状张量[16,1]

有没有人知道更容易解决这种情况?我想知道是否可以使用tf.Print来打印所需的张量。

1 个答案:

答案 0 :(得分:0)

好的,问题是我正在修改输出但是没有更新LSTM本身的output_size。因此错误。它现在完美无缺。但是,我仍然觉得这种方法非常烦人。不接受我自己的答案,希望有人能有更清洁的解决方案。