Tensorflow dynamic_rnn参数含义

时间:2017-01-27 00:48:19

标签: tensorflow

我正在努力理解神秘的RNN文档。任何有关以下的帮助将不胜感激。

tf.nn.dynamic_rnn(cell, inputs, sequence_length=None, initial_state=None, dtype=None, parallel_iterations=None, swap_memory=False, time_major=False, scope=None)

我很难理解这些参数如何与数学LSTM方程和RNN定义相关。单元格展开的大小在哪里?它是由' max_time'定义的。输入的维度? batch_size只是分割长数据的便利还是与minibatch SGD相关?输出状态是否通过批次传递?

1 个答案:

答案 0 :(得分:15)

tf.nn.dynamic_rnn接收不相关序列的批次(具有小批量意义)。

  • cell是您要使用的实际单元格(LSTM,GRU,...)
  • inputs的形状为batch_size x max_time x input_size,其中max_time是最长序列中的步数(但所有序列的长度可能相同)
  • sequence_length是一个大小为batch_size的向量,其中每个元素都给出了批处理中每个序列的长度(如果所有序列的大小相同,则将其保留为默认值。这参数是定义单元格展开大小的参数。

隐藏状态处理

处理隐藏状态的常用方法是在dynamic_rnn之前定义一个初始状态张量,例如:

hidden_state_in = cell.zero_state(batch_size, tf.float32) 
output, hidden_state_out = tf.nn.dynamic_rnn(cell, 
                                             inputs,
                                             initial_state=hidden_state_in,
                                             ...)

在上面的代码段中,hidden_state_inhidden_state_out具有相同的形状[batch_size, ...]实际形状取决于您使用的单元格类型,但重要的是第一个维度是批量大小)。

这样,dynamic_rnn每个序列都有一个初始隐藏状态。 它将自己的inputs参数中的每个序列从时间步长传递到隐藏状态hidden_state_out将包含每个序列的最终输出状态在批次中。在同一批次的序列之间不传递隐藏状态,但仅在相同序列的时间步长之间传递。

我什么时候需要手动反馈隐藏状态?

通常,当您进行培训时,每个批次都不相关,因此您在执行session.run(output)时无需反馈隐藏状态。

但是,如果您正在测试,并且您需要在每个时间步骤输出,(即您必须在每个时间步执行session.run()),您将需要评估并反馈隐藏的输出国家使用这样的东西:

output, hidden_state = sess.run([output, hidden_state_out],
                                feed_dict={hidden_state_in:hidden_state})

否则tensorflow将在每个时间步使用默认cell.zero_state(batch_size, tf.float32),这相当于在每个时间步重新初始化隐藏状态。