在Tensorflow数据集API

时间:2018-07-06 17:53:35

标签: python tensorflow tensorflow-datasets

我正在尝试使用dataset api加载数据,发现我花了大部分时间将数据加载到shuffle缓冲区中。我该如何优化此管道,以最大程度地减少填充混洗缓冲区所花费的时间。

(tf.data.Dataset.list_files(path)
   .shuffle(num_files)  # number of tfrecord files 
   .apply(tf.contrib.data.parallel_interleave(lambda f: tf.data.TFRecordDataset(f), cycle_length=num_files))
   .shuffle(num_items)  # number of images in the dataset
   .map(parse_func, num_parallel_calls=8)
   .map(get_patches, num_parallel_calls=8)
   .apply(tf.contrib.data.unbatch())
   # Patch buffer is currently the number of patches extracted per image
   .apply(tf.contrib.data.shuffle_and_repeat(patch_buffer))
   .batch(64)
   .prefetch(1)
   .make_one_shot_iterator())

1 个答案:

答案 0 :(得分:2)

由于我最多拥有数千个图像,因此,针对此问题的解决方案是每个图像有一个单独的tfrecord文件。这样一来,无需先将单个图像加载到内存中即可对各个图像进行混洗。这大大减少了需要发生的缓冲。