从PaddingFIFOQueue获取动态序列长度

时间:2016-11-18 11:13:43

标签: tensorflow

我使用tf.PaddingFIFOQueuetf.contrib.data.PaddedBatchDataset来提供不同长度和dequeue_many的序列,以获得零填充批量。

是否有一些通用的方法来获得该批次的序列长度?

我目前的解决方案是明确提供序列长度作为队列的附加输入,即我有tf.PaddingFIFOQueue(names=["data", "seq_length"], ...)。我也可以使用tf.ones_like(),但我目前的方式似乎更便宜,更简单。但我想知道这是规范/标准的方式,还是有其他方式。

1 个答案:

答案 0 :(得分:0)

您可以将dataseq_length合并为一个元组(或列表),然后将元组推入队列。

import tensorflow as tf
sess = tf.InteractiveSession()
q = tf.PaddingFIFOQueue(capacity=10, dtypes=[tf.int32, tf.int32], shapes=[[], [None]])
eq1 = q.enqueue([1, [1]])
eq2 = q.enqueue([2, [2,3]])
eq3 = q.enqueue([3, [4,5,6]])
dq = q.dequeue()
sess.run(eq1)
sess.run(eq2)
sess.run(eq3)
sess.run(dq)  # [1, array([1], dtype=int32)]
sess.run(dq)  # [2, array([2, 3], dtype=int32)]
sess.run(dq)  # [3, array([4, 5, 6], dtype=int32)]