我使用tf.PaddingFIFOQueue
或tf.contrib.data.PaddedBatchDataset
来提供不同长度和dequeue_many
的序列,以获得零填充批量。
是否有一些通用的方法来获得该批次的序列长度?
我目前的解决方案是明确提供序列长度作为队列的附加输入,即我有tf.PaddingFIFOQueue(names=["data", "seq_length"], ...)
。我也可以使用tf.ones_like()
,但我目前的方式似乎更便宜,更简单。但我想知道这是规范/标准的方式,还是有其他方式。
答案 0 :(得分:0)
您可以将data
和seq_length
合并为一个元组(或列表),然后将元组推入队列。
import tensorflow as tf
sess = tf.InteractiveSession()
q = tf.PaddingFIFOQueue(capacity=10, dtypes=[tf.int32, tf.int32], shapes=[[], [None]])
eq1 = q.enqueue([1, [1]])
eq2 = q.enqueue([2, [2,3]])
eq3 = q.enqueue([3, [4,5,6]])
dq = q.dequeue()
sess.run(eq1)
sess.run(eq2)
sess.run(eq3)
sess.run(dq) # [1, array([1], dtype=int32)]
sess.run(dq) # [2, array([2, 3], dtype=int32)]
sess.run(dq) # [3, array([4, 5, 6], dtype=int32)]