如何在tensorflow中将static_rnn输入转换为dynamic_rnn输入?

时间:2017-04-13 15:36:43

标签: tensorflow

我无法理解tensorflow dynamic_rnn的输入参数。如果我能理解如何将static_rnn输入转换为dynamic_rnn输入,那将会有很大帮助。

对于static_rnn,输入应该是长度为T的张量列表,其形状为[batch_size, input_size],其中T是序列长度。这对我来说很有意义。

对于dynamic_rnn,输入应该是形状张量[batch_size, max_time, ...]。我不明白如何在这里加入input_size。更一般地说,我不知道你还可以在省略号中添加什么。

例如,假设我的数据由50个字符长的句子组成,因此input_size是字母表中的字母数。对于static_rnn,我会制作长度为50的张量,其形状为[batch_size, input_size]。如何将此张量列表转换为单个张量,以便我可以将其提供给dynamic_rnn

1 个答案:

答案 0 :(得分:8)

您的dynamic_rnn输入应该是[batch_size, sequence_length, input_size]的形状。

基本上张量代表长度为batch_size的{​​{1}}个例子,省略号中剩下的是单个序列元素的形状。

问题是,sequence_length您不需要事先知道sequence_length,因此您的输入占位符可能看起来像

dynamic_rnn

这非常方便。此外,单个批处理中的示例可以具有不同的长度(但必须填充到相同的长度),但您必须将x = tf.placeholder(tf.int32, shape=(batch_size, None, input_size)) 参数传递给sequence_length,以便它知道何时停止每个示例的计算。