我在不同长度的批次序列上训练LSTM细胞。 tf.nn.rnn
具有非常方便的参数sequence_length
,但在调用之后,我不知道如何选择与批次中每个项目的最后一个步骤相对应的输出行。
我的代码基本如下:
lstm_cell = tf.nn.rnn_cell.LSTMCell(num_lstm_units, input_size)
lstm_outputs, state = tf.nn.rnn(lstm_cell, input_list, dtype=tf.float32, sequence_length=sequence_lengths)
lstm_outputs
是每个时间步都有LSTM输出的列表。但是,我的批次中的每个项目都有不同的长度,因此我想创建一个张量,其中包含对我批次中每个项目有效的最后一个LSTM输出。
如果我可以使用numpy索引,我会做这样的事情:
all_outputs = tf.pack(lstm_outputs)
last_outputs = all_outputs[sequence_lengths, tf.range(batch_size), :]
但事实证明,开始张量流并不支持它(我知道feature request)。
那么,我怎么能得到这些值?
答案 0 :(得分:5)
danijar在我在问题中链接的功能请求页面上发布了一个更可接受的解决方法。它不需要评估张量,这是一个很大的优点。
我让它与tensorflow 0.8一起工作。这是代码:
def extract_last_relevant(outputs, length):
"""
Args:
outputs: [Tensor(batch_size, output_neurons)]: A list containing the output
activations of each in the batch for each time step as returned by
tensorflow.models.rnn.rnn.
length: Tensor(batch_size): The used sequence length of each example in the
batch with all later time steps being zeros. Should be of type tf.int32.
Returns:
Tensor(batch_size, output_neurons): The last relevant output activation for
each example in the batch.
"""
output = tf.transpose(tf.pack(outputs), perm=[1, 0, 2])
# Query shape.
batch_size = tf.shape(output)[0]
max_length = int(output.get_shape()[1])
num_neurons = int(output.get_shape()[2])
# Index into flattened array as a workaround.
index = tf.range(0, batch_size) * max_length + (length - 1)
flat = tf.reshape(output, [-1, num_neurons])
relevant = tf.gather(flat, index)
return relevant
答案 1 :(得分:2)
这不是最好的解决方案,但你可以评估你的输出然后只是使用numpy索引来获得结果并从中创建一个张量变量?它可能起到一个停止间隙,直到tensorflow获得此功能。 e.g。
all_outputs = session.run(lstm_outputs, feed_dict={'your inputs'})
last_outputs = all_outputs[sequence_lengths, tf.range(batch_size), :]
use_this_as_an_input_to_new_tensorflow_op = tf.constant(last_outputs)
答案 2 :(得分:1)
如果您只对最后一个有效输出感兴趣,可以通过LSTMStateTuple
返回的状态检索它,因为它总是一个元组(c,h),其中c是最后一个state和h是最后一个输出。当状态为lstm_cell = tf.nn.rnn_cell.LSTMCell(num_lstm_units, input_size)
lstm_outputs, state = tf.nn.rnn(lstm_cell, input_list, dtype=tf.float32, sequence_length=sequence_lengths)
last_output = state[1]
时,您可以使用以下代码段(在tensorflow 0.12中工作):
{{1}}