在tensorflow中获取dynamic_rnn的最后一个输出?

时间:2016-04-23 23:25:13

标签: python tensorflow

我正在使用dynamic_rnn来处理MNIST数据:

# LSTM Cell
lstm = rnn_cell.LSTMCell(num_units=200,
                         forget_bias=1.0,
                         initializer=tf.random_normal)

# Initial state
istate = lstm.zero_state(batch_size, "float")

# Get lstm cell output
output, states = rnn.dynamic_rnn(lstm, X, initial_state=istate)

# Output at last time point T
output_at_T = output[:, 27, :]

完整代码:http://pastebin.com/bhf9MgMe

lstm的输入是(batch_size, sequence_length, input_size)

因此output_at_T的维度为(batch_size, sequence_length, num_units) num_units=200

我需要获取sequence_length维度的最后一个输出。在上面的代码中,这是硬编码为27。但是,我事先并不知道sequence_length,因为它可以在我的应用程序中从批处理更改为批处理。

我试过了:

output_at_T = output[:, -1, :]

但它说负面索引尚未实现,我尝试使用占位符变量和常量(我可以理想地为特定批次提供sequence_length);没有工作。

任何在tensorflow atm中实现类似内容的方法吗?

7 个答案:

答案 0 :(得分:13)

您是否注意到dynamic_rnn有两个输出?

  1. 输出1,我们称之为h,每个时间步都有所有输出(即h_1,h_2等),
  2. 输出2,final_state,有两个元素:cell_state,以及批处理中每个元素的最后一个输出(只要您将序列长度输入到dynamic_rnn)。
  3. 所以来自:

    h, final_state= tf.dynamic_rnn( ..., sequence_length=[batch_size_vector], ... )
    

    批次中每个元素的最后一个状态是:

    final_state.h
    

    请注意,这包括当批处理的每个元素的序列长度不同时的情况,因为我们正在使用sequence_length参数。

答案 1 :(得分:5)

这是gather_nd的用途!

def extract_axis_1(data, ind):
    """
    Get specified elements along the first axis of tensor.
    :param data: Tensorflow tensor that will be subsetted.
    :param ind: Indices to take (one for each element along axis 0 of data).
    :return: Subsetted tensor.
    """

    batch_range = tf.range(tf.shape(data)[0])
    indices = tf.stack([batch_range, ind], axis=1)
    res = tf.gather_nd(data, indices)

    return res

在你的情况下(假设sequence_length是一个1-D张量,每个轴的长度为0元素):

output = extract_axis_1(output, sequence_length - 1)

现在输出是维度[batch_size, num_cells]的张量。

答案 2 :(得分:2)

大多数答案都会详尽地介绍它,但此代码片段可能有助于了解dynamic_rnn图层实际返回的内容

=> (输出,final_output_state)的元组。

因此,对于具有最大序列长度的T时间步输出的输入,其形状为[Batch_size, T, num_inputs](给定time_major = False;默认值它包含每个时间步长h1, h2.....hT的输出状态。

final_output_state 的形状为[Batch_size,num_inputs],并且具有每个批次序列的最终单元格状态cT和输出状态hT

但是由于dynamic_rnn被使用,我的猜测是每个批次的序列长度不同。

    import tensorflow as tf
    import numpy as np
    from tensorflow.contrib import rnn
    tf.reset_default_graph()

    # Create input data
    X = np.random.randn(2, 10, 8)

    # The second example is of length 6 
    X[1,6:] = 0
    X_lengths = [10, 6]

    cell = tf.nn.rnn_cell.LSTMCell(num_units=64, state_is_tuple=True)

    outputs, states  = tf.nn.dynamic_rnn(cell=cell,
                                         dtype=tf.float64,
                                         sequence_length=X_lengths,
                                         inputs=X)

    result = tf.contrib.learn.run_n({"outputs": outputs, "states":states},
                                    n=1,
                                    feed_dict=None)
    assert result[0]["outputs"].shape == (2, 10, 64)
    print result[0]["outputs"].shape
    print result[0]["states"].h.shape
    # the final outputs state and states returned must be equal for each      
    # sequence
    assert(result[0]["outputs"][0][-1]==result[0]["states"].h[0]).all()
    assert(result[0]["outputs"][-1][5]==result[0]["states"].h[-1]).all()
    assert(result[0]["outputs"][-1][-1]==result[0]["states"].h[-1]).all()

最终断言将失败,因为第二个序列的最终状态是第6个步骤,即。索引5和[6:9]的其余输出在第二个时间步长都是0

答案 3 :(得分:1)

我是Stackoverflow的新手,无法发表评论所以我正在撰写这个新答案。 @VM_AI,最后一个索引是tf.shape(output)[1] - 1。 所以,重复你的回答:

# Let's first fetch the last index of seq length
# last_index would have a scalar value
last_index = tf.shape(output)[1] - 1
# Then let's reshape the output to [sequence_length,batch_size,num_units]
# for convenience
output_rs = tf.transpose(output,[1,0,2])
# Last state of all batches
last_state = tf.nn.embedding_lookup(output_rs,last_index)

这对我有用。

答案 4 :(得分:1)

output[:, -1, :]

现在与Tensorflow 1.x合作!!

答案 5 :(得分:0)

您应该可以使用output访问tf.shape(output)张量的形状。 tf.shape()函数将返回包含output张量大小的1d张量。在您的示例中,这将是(batch_size, sequence_length, num_units)

然后,您应该能够将output_at_T的值提取为output[:, tf.shape(output)[1], :]

答案 6 :(得分:0)

TensorFlow def array_count9(nums): count = 0 count += 1 if i == 9 else count == count for i in nums return count 中有一个函数可以让您获得形状的符号解释,而不是tf.shape返回的None。在获取最后一个索引之后,您可以使用output._shape[1]进行查找,特别是当要提取的数据很高时,建议使用tf.nn.embedding_lookup进行查找。

32 by default

这应该有用。

只是为了澄清@Benoit Steiner说的话。他的解决方案不起作用,因为tf.shape将返回形状值的符号解释,并且不能用于切片张量,即直接索引