在LSTMStateTuple上调用reshape将其转换为张量

时间:2017-05-25 10:09:35

标签: tensorflow reshape lstm recurrent-neural-network tensor

我正在使用带有LSTMCell的dynamic_rnn,它发出了一个包含内部状态的LSTMStateTuple。调用此对象的重塑(由于我的错误)会导致张量,而不会在图形创建时导致任何错误。通过图表输入输入时,我在运行时没有收到任何错误。

代码:

cell = tf.contrib.rnn.LSTMCell(size, state_is_tuple=True, ...)
outputs, states = tf.nn.dynamic_rnn(cell, inputs, ...)
print(states) # state is an LSTMStateTuple
states = tf.reshape(states, [-1, size])
print(states) # state is a tensor of shape [?, size]

这是一个错误(我问,因为它没有在任何地方记录)?什么是重塑的张量持有?

1 个答案:

答案 0 :(得分:0)

我进行了类似的实验,可能会给你一些提示:

>>> s = tf.constant([[0, 0, 0, 1, 1, 1],
                     [2, 2, 2, 3, 3, 3]])
>>> t = tf.constant([[4, 4, 4, 5, 5, 5],                                                             
                     [6, 6, 6, 7, 7, 7]])
>>> g = tf.reshape((s, t), [-1, 3]) # <tf.Tensor 'Reshape_1:0' shape=(8, 3) dtype=int32>
>>> sess.run(g)
array([[0, 0, 0],
       [1, 1, 1],
       [2, 2, 2],
       [3, 3, 3],
       [4, 4, 4],
       [5, 5, 5],
       [6, 6, 6],
       [7, 7, 7]], dtype=int32)

我们可以看到它只是在第一维中连接两个张量并执行重塑。由于LSTMStateTuple就像一个namedtuple,它与元组具有相同的效果,我认为这也是你的情况。

让我们走得更远,

>>> st = tf.contrib.rnn.LSTMStateTuple(s, t)
>>> gg = tf.reshape(st, [-1, 3])
>>> sess.run(gg)
    array([[0, 0, 0],
           [1, 1, 1],
           [2, 2, 2],
           [3, 3, 3],
           [4, 4, 4],
           [5, 5, 5],
           [6, 6, 6],
           [7, 7, 7]], dtype=int32)

我们可以看到,如果我们创建一个LSTMStateTuple,结果将验证我们的假设。