我正在将旧的数据层(使用队列)切换到“新”和推荐的数据集API。我是第一次使用它,所以我提供代码示例以防万一我遇到了根本性错误。
我从生成器创建数据集(将读取文件,并提供n个样本)。这是一个小数据集和n_iterations>> n_samples,所以我只是想一遍又一遍地阅读这个数据集,理想的是洗牌。
sample_set = tf.data.Dataset.from_generator( data_generator(filename),
(tf.uint8, tf.uint8), (tf.TensorShape([256,256,4]), tf.TensorShape([256,256,1]))
)
使用datagenerator:
class data_generator:
def __init__(self, filename):
self.filename= filename
def __call__(self):
with filename.open() as f:
for idx in f: yield img[idx], label[idx]
为了实际使用数据,我得到了我需要定义Iterator
sample = sample_set.make_one_shot_iterator().get_next()
然后我们设置读取数据
while True:
try: my_sample = sess.run(sample)
except tf.errors.OutOfRangeError: break # this happens after dset is read once
但是所有可用的迭代器似乎都是“有限的”,就像他们只读取一次数据集一样。
是否有一种简单的方法可以从数据集中无休止地阅读?
答案 0 :(得分:3)
BUF_SIZE = 100 # choose it depending on your data
sample_set = tf.data.Dataset.from_generator( data_generator(filename),
(tf.uint8, tf.uint8), (tf.TensorShape([256,256,4]),
tf.TensorShape([256,256,1]))
).repeat().shuffle(BUF_SIZE)
答案 1 :(得分:1)
Dataset.repeat()
转换会在您不通过明确count
的情况下无休止地重复数据集:
sample_set = tf.data.Dataset.from_generator(
data_generator(filename), (tf.uint8, tf.uint8),
(tf.TensorShape([256,256,4]), tf.TensorShape([256,256,1])))
# Repeats `sample_set` endlessly.
sample_set = sample_set.repeat()
sample = sample_set.make_one_shot_iterator().get_next()
答案 2 :(得分:0)
可重新初始化的迭代器可以重新初始化同一个数据集,因此这段代码会一遍又一遍地读取相同的数据集:
sample = tf.data.Iterator.from_structure(sample_set.output_types,
sample_set.output_shapes).get_next()
sample_it.make_initializer(sample_set) # create initialize op
with tf.Session(config=config) as sess:
sess.run(sample_set_init_op) # initialize in the beginning
while True:
try:
my_sample = sess.run(sample)
except tf.errors.OutOfRangeError:
sess.run(sample_set_init_op) # re-initialize on same dataset