混合多个tf.data.Dataset?

时间:2018-04-24 06:07:27

标签: tensorflow

我有三个数据集D1,D2,D3,它们输出相同类型的数据。我要做的是从一个独特的管道中随机输出D1或D2或D3。我尝试使用tf.data.Dataset.zip((D1, D2, D3)),但后来我不知道如何压平它的输出以便将其洗牌然后输出D1_element, D3_element,D1_element , D2_element ... 这是一个小例子:

import tensorflow as tf

D1 = tf.data.Dataset.range(1,5)
D2 = tf.data.Dataset.range(5,10)
D3 = tf.data.Dataset.range(10,15)

zip = tf.data.Dataset.zip((D1,D2,D2))
...

1 个答案:

答案 0 :(得分:1)

如果有兴趣的话,我找到了以下解决方案:

import tensorflow as tf

def stack(*inputs):
    return tf.stack(inputs)

D1 = tf.data.Dataset.range(1,5)
D2 = tf.data.Dataset.range(5,10)
D3 = tf.data.Dataset.range(10,15)

D = tf.data.Dataset.zip((D1,D2,D3))
D = D.map(stack)
D = D.apply(tf.contrib.data.unbatch())
D = D.shuffle(10, seed=0)
D = D.batch(3)
D = D.prefetch(1)

it = D.make_one_shot_iterator()
next_element = it.get_next()

with tf.Session() as sess:
    print sess.run(next_element)