如何用张量索引LSTMStateTuple列表?

时间:2018-10-12 23:26:31

标签: python tensorflow lstm

我有一个LSTMStateTuple对象的python列表,我必须使用张量作为索引来检索它们。例如:

index = tf.constant(0)
lstm = tf.nn.rnn_cell.LSTMCell(128)
states = [lstm.zero_state(10, tf.float32), lstm.zero_state(10, tf.float32)]

如果我尝试state = states[index]会出错,并且state = tf.gather(states, index)states转换为张量并返回形状为[10, 2, 128]的张量。

如何获得LSTMStateTuple而不是张量?当我将状态传递给lstm时,我想避免从LSTMStateTuple列表到张量的转换以及从张量到LSTMStateTuple的转换。

1 个答案:

答案 0 :(得分:0)

您创建了两个状态,并将它们置于LSTMStateTuple中。

cell_state = tf.placeholder(tf.float32, [batch_size, state_size])
hidden_state = tf.placeholder(tf.float32, [batch_size, state_size])
init_state = tf.nn.rnn_cell.LSTMStateTuple(cell_state, hidden_state)