我有一个LSTMStateTuple
对象的python列表,我必须使用张量作为索引来检索它们。例如:
index = tf.constant(0)
lstm = tf.nn.rnn_cell.LSTMCell(128)
states = [lstm.zero_state(10, tf.float32), lstm.zero_state(10, tf.float32)]
如果我尝试state = states[index]
会出错,并且state = tf.gather(states, index)
将states
转换为张量并返回形状为[10, 2, 128]
的张量。
如何获得LSTMStateTuple
而不是张量?当我将状态传递给lstm时,我想避免从LSTMStateTuple
列表到张量的转换以及从张量到LSTMStateTuple
的转换。
答案 0 :(得分:0)
您创建了两个状态,并将它们置于LSTMStateTuple
中。
cell_state = tf.placeholder(tf.float32, [batch_size, state_size])
hidden_state = tf.placeholder(tf.float32, [batch_size, state_size])
init_state = tf.nn.rnn_cell.LSTMStateTuple(cell_state, hidden_state)