Tensorflow Dataset API-行为说明

时间:2019-02-13 20:34:40

标签: tensorflow tensorflow-datasets tfrecord

使用下面的代码,我想问一些关于下面到底发生了什么的问题。

dataset = tf.data.TFRecordDataset(filepath)
dataset = dataset.map(parse_function, num_parallel_calls=4)
dataset = dataset.repeat()
dataset = dataset.shuffle(1024)
dataset = dataset.batch(16)
iterator = dataset.make_one_shot_iterator()

1。dataset.map(parse_function, num_parallel_calls=4)-我们在这里加载多少条记录?内存或固定数目可以容纳多少?

2。dataset = dataset.repeat()-我们到底要重复什么?当前从.1点加载的数据?如果是这样,是否意味着我们将不再加载其他文件?

3。随机播放到底如何工作?

4。我们可以在映射之前使用重复,随机播放和批处理并在文件路径上工作,而不是单独使用文件吗?

2 个答案:

答案 0 :(得分:0)

  1. 您正在这里加载整个数据集。在批处理之前应用地图通常不是一个好主意。 Tensorflow对张量大小有2GB的硬限制。 num_parallel_calls表示并行应用的映射函数的数量。
  2. dataset.repeat()没有指定纪元值将无限期重复数据集。
  3. 随机播放将随机播放具有指定缓冲区值的数据集。为了正确地改组,通常最好将此值设置为数据集长度,并在批处理之前应用此功能。
  4. tf.data.TFRecordDataset期望文件名作为输入。通常,首选顺序是

    dataset = dataset.shuffle(shuffle_buffer).repeat()
    dataset = dataset.batch(batch_size)
    dataset = dataset.map(map_func)
    

看看https://www.tensorflow.org/guide/performance/datasets

答案 1 :(得分:0)

  1. Dataset API中的数据是延迟加载的,因此它取决于以后的操作。现在,由于随机缓冲区的大小,您一次可以加载1024个样本。它需要填充随机播放缓冲区。当您从迭代器中获取值时,数据将被延迟加载。
  2. 您重复加载的数据,因为重复是在map函数之后。这就是为什么建议在解析数据之前先进行洗牌,因为它对内存更友好。
  3. 随机播放会加载一些数据(取决于随机播放缓冲区的大小),并随机播放这些数据。
  4. 是的,您可以重复,随机播放然后映射,甚至在performance guide中建议。还有将repeatshuffle合并在一起的功能here