使用tf.data.Dataset批量/串联不同大小的张量

时间:2018-07-21 17:10:41

标签: python tensorflow dataset

我正在尝试使用tf.data.Dataset API建立输入管道。下面是一个非常简化的MWE,我正在其中尝试对不同大小的张量进行批处理。我知道一个事实,即批处理要求所有张量都具有相同的形状。但是,我有兴趣在第一维中以串联形式进行批处理,而不是创建新维。 tf.data.Dataset.padded_batch对我来说不是一个选择,因为它会带来大量的不必要的开销。

import tensorflow as tf

if __name__ == '__main__':
    dataset = tf.data.Dataset.range(10)
    dataset = dataset.map(lambda x: tf.range(0, x + 1))
    dataset = dataset.batch(1)  # Batching using concatenation?
    iterator = dataset.make_one_shot_iterator()

    batch = iterator.get_next()

    with tf.Session() as sess:
        while True:
            try:
                print(sess.run(batch))
            except tf.errors.OutOfRangeError:
                break

输出:

  

[[0]]
  [[0 1]]
  [[0 1 2]]
  [[0 1 2 3]]

我要实现的目标(例如,批量大小设置为2)

  

[0 0 1]
  [0 1 2 0 1 2 3]

谢谢!

0 个答案:

没有答案