如何从Tensorflow tf.data.Dataset无休止地阅读?

时间:2018-03-19 15:33:06

标签: tensorflow tensorflow-datasets

我正在将旧的数据层(使用队列)切换到“新”和推荐的数据集API。我是第一次使用它,所以我提供代码示例以防万一我遇到了根本性错误。

我从生成器创建数据集(将读取文件,并提供n个样本)。这是一个小数据集和n_iterations>> n_samples,所以我只是想一遍又一遍地阅读这个数据集,理想的是洗牌。

sample_set = tf.data.Dataset.from_generator( data_generator(filename),  
    (tf.uint8, tf.uint8), (tf.TensorShape([256,256,4]), tf.TensorShape([256,256,1]))
)

使用datagenerator:

class data_generator:
    def __init__(self, filename):
        self.filename= filename

    def __call__(self):
        with filename.open() as f:
           for idx in f: yield img[idx], label[idx]

为了实际使用数据,我得到了我需要定义Iterator

sample = sample_set.make_one_shot_iterator().get_next()

然后我们设置读取数据

while True:
    try: my_sample = sess.run(sample)
    except tf.errors.OutOfRangeError: break   # this happens after dset is read once

但是所有可用的迭代器似乎都是“有限的”,就像他们只读取一次数据集一样。

是否有一种简单的方法可以从数据集中无休止地阅读?

3 个答案:

答案 0 :(得分:3)

数据集包含repeatshuffle方法。

BUF_SIZE = 100 # choose it depending on your data
sample_set = tf.data.Dataset.from_generator( data_generator(filename),  
    (tf.uint8, tf.uint8), (tf.TensorShape([256,256,4]), 
    tf.TensorShape([256,256,1]))
).repeat().shuffle(BUF_SIZE)

答案 1 :(得分:1)

Dataset.repeat()转换会在您不通过明确count的情况下无休止地重复数据集:

sample_set = tf.data.Dataset.from_generator(
    data_generator(filename), (tf.uint8, tf.uint8),
    (tf.TensorShape([256,256,4]), tf.TensorShape([256,256,1])))

# Repeats `sample_set` endlessly.
sample_set = sample_set.repeat()

sample = sample_set.make_one_shot_iterator().get_next()

答案 2 :(得分:0)

可重新初始化的迭代器可以重新初始化同一个数据集,因此这段代码会一遍又一遍地读取相同的数据集:

sample = tf.data.Iterator.from_structure(sample_set.output_types,
                                         sample_set.output_shapes).get_next()

sample_it.make_initializer(sample_set)     # create initialize op

with tf.Session(config=config) as sess:
    sess.run(sample_set_init_op)           # initialize in the beginning

    while True:
        try: 
             my_sample = sess.run(sample)
        except tf.errors.OutOfRangeError:
             sess.run(sample_set_init_op)  # re-initialize on same dataset