来自生成器的数据集,一次生成多个元素

时间:2018-11-27 13:21:00

标签: python tensorflow tensorflow-datasets

我正在测试是否需要从不推荐使用的基于队列的API迁移到TensorFlow中的数据集API的地方。

我似乎找不到与之等效的一个用例是enqueue_many的{​​{1}}参数。

特别是我想创建一个Python生成器,它可以产生“批处理”数组,其中“批处理大小”不一定与用于SGD训练更新的数组相同,然后将批处理应用于该数据流(即与tf.train.batch中的enqueue_many一样。

在新的数据集API中是否有解决方法?

1 个答案:

答案 0 :(得分:0)

尝试使用平面图

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
n_reads=10
read_batch_size=20
training_batch_size = 2

def mnist_gen():
    mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
    for i in range(n_reads):
        batch_x, batch_y = mnist.train.next_batch(read_batch_size)
        # Yielding a batch instead of single record
        yield batch_x,batch_y
data = tf.data.Dataset.from_generator(mnist_gen,output_types=(tf.float32,tf.float32))
data = data.flat_map(lambda *x: tf.data.Dataset.zip(tuple(map(tf.data.Dataset.from_tensor_slices,x)))).batch(training_batch_size)
# if u yield only batch_x change lambda function to data.flat_map(lambda x: tf.data.Dataset.from_tensor_slices(x)))
iter = data.make_one_shot_iterator()
next_item = iter.get_next()

X= next_item[0]
Y = next_item[1]

with tf.Session() as sess:
    for i in range(n_reads*read_batch_size // training_batch_size):
        print(i, sess.run(X))