我有三个数据集D1,D2,D3,它们输出相同类型的数据。我要做的是从一个独特的管道中随机输出D1或D2或D3。我尝试使用tf.data.Dataset.zip((D1, D2, D3))
,但后来我不知道如何压平它的输出以便将其洗牌然后输出D1_element, D3_element,D1_element , D2_element ...
这是一个小例子:
import tensorflow as tf
D1 = tf.data.Dataset.range(1,5)
D2 = tf.data.Dataset.range(5,10)
D3 = tf.data.Dataset.range(10,15)
zip = tf.data.Dataset.zip((D1,D2,D2))
...
答案 0 :(得分:1)
如果有兴趣的话,我找到了以下解决方案:
import tensorflow as tf
def stack(*inputs):
return tf.stack(inputs)
D1 = tf.data.Dataset.range(1,5)
D2 = tf.data.Dataset.range(5,10)
D3 = tf.data.Dataset.range(10,15)
D = tf.data.Dataset.zip((D1,D2,D3))
D = D.map(stack)
D = D.apply(tf.contrib.data.unbatch())
D = D.shuffle(10, seed=0)
D = D.batch(3)
D = D.prefetch(1)
it = D.make_one_shot_iterator()
next_element = it.get_next()
with tf.Session() as sess:
print sess.run(next_element)