我正在尝试使用tf.data.Dataset
API建立输入管道。下面是一个非常简化的MWE,我正在其中尝试对不同大小的张量进行批处理。我知道一个事实,即批处理要求所有张量都具有相同的形状。但是,我有兴趣在第一维中以串联形式进行批处理,而不是创建新维。 tf.data.Dataset.padded_batch
对我来说不是一个选择,因为它会带来大量的不必要的开销。
import tensorflow as tf
if __name__ == '__main__':
dataset = tf.data.Dataset.range(10)
dataset = dataset.map(lambda x: tf.range(0, x + 1))
dataset = dataset.batch(1) # Batching using concatenation?
iterator = dataset.make_one_shot_iterator()
batch = iterator.get_next()
with tf.Session() as sess:
while True:
try:
print(sess.run(batch))
except tf.errors.OutOfRangeError:
break
输出:
[[0]]
[[0 1]]
[[0 1 2]]
[[0 1 2 3]]
我要实现的目标(例如,批量大小设置为2)
[0 0 1]
[0 1 2 0 1 2 3]
谢谢!