我正在尝试使用dataset
api加载数据,发现我花了大部分时间将数据加载到shuffle缓冲区中。我该如何优化此管道,以最大程度地减少填充混洗缓冲区所花费的时间。
(tf.data.Dataset.list_files(path)
.shuffle(num_files) # number of tfrecord files
.apply(tf.contrib.data.parallel_interleave(lambda f: tf.data.TFRecordDataset(f), cycle_length=num_files))
.shuffle(num_items) # number of images in the dataset
.map(parse_func, num_parallel_calls=8)
.map(get_patches, num_parallel_calls=8)
.apply(tf.contrib.data.unbatch())
# Patch buffer is currently the number of patches extracted per image
.apply(tf.contrib.data.shuffle_and_repeat(patch_buffer))
.batch(64)
.prefetch(1)
.make_one_shot_iterator())
答案 0 :(得分:2)
由于我最多拥有数千个图像,因此,针对此问题的解决方案是每个图像有一个单独的tfrecord文件。这样一来,无需先将单个图像加载到内存中即可对各个图像进行混洗。这大大减少了需要发生的缓冲。