tensorflow的tf.contrib.training.batch_sequences_with_states API如何工作?

时间:2017-10-31 12:49:10

标签: python tensorflow deep-learning lstm rnn

我正在处理必须传递给RNN的长序列数据。要做截断的BPTT和批处理,似乎有两种选择:

  1. 通过组合来自不同序列的各个段来创建批处理。保留批次中每个序列的最终状态,并将其传递给下一批。
  2. 将每个序列视为一个小批处理,序列中的段成为批处理的成员。保留一个段中最后一个时间步的状态,并将其传递给下一个段的第一个时间步。
  3. 我遇到了tf.contrib.training.batch_sequences_with_states,似乎正在做两个中的一个。文档让我感到困惑,因此我想确定它生成批次的方式。

    我的猜测是第一种方式。这是因为,如果批处理是以第二种方式完成的,那么我们就无法利用矢量化的好处,因为,为了保持一个分段的最后一个步骤到下一个分段的第一个时间步之间的状态,RNN应该处理一个令牌一次一个。

    问题:

    tf.contrib.training.batch_sequences_with_states中实现了这两种批处理策略中的哪一种?

1 个答案:

答案 0 :(得分:2)

tf.contrib.training.batch_sequences_with_states实现了以前的行为。每个小批量条目是来自不同序列的一个段(每个序列,可以由可变数量的段组成,具有唯一的密钥,并且该密钥被传递到batch_sequences_with_states)。与state_saving_rnn一起使用时,每个段的最终状态将保存回一个特殊的存储容器,该容器允许给定序列的下一个段在下一个sess.run运行。最后的片段释放了一个不同序列的小批量插槽。