我正在处理必须传递给RNN的长序列数据。要做截断的BPTT和批处理,似乎有两种选择:
我遇到了tf.contrib.training.batch_sequences_with_states
,似乎正在做两个中的一个。文档让我感到困惑,因此我想确定它生成批次的方式。
我的猜测是第一种方式。这是因为,如果批处理是以第二种方式完成的,那么我们就无法利用矢量化的好处,因为,为了保持一个分段的最后一个步骤到下一个分段的第一个时间步之间的状态,RNN应该处理一个令牌一次一个。
问题:
tf.contrib.training.batch_sequences_with_states
中实现了这两种批处理策略中的哪一种?
答案 0 :(得分:2)
tf.contrib.training.batch_sequences_with_states
实现了以前的行为。每个小批量条目是来自不同序列的一个段(每个序列,可以由可变数量的段组成,具有唯一的密钥,并且该密钥被传递到batch_sequences_with_states
)。与state_saving_rnn
一起使用时,每个段的最终状态将保存回一个特殊的存储容器,该容器允许给定序列的下一个段在下一个sess.run
运行。最后的片段释放了一个不同序列的小批量插槽。