Tensorflow数据集API中的内存管理

时间:2017-07-16 03:35:06

标签: tensorflow tensorflow-datasets

我的训练数据集太大而无法容纳到内存中,因此我的代码一次只能从磁盘读取1,000条记录。现在我想使用Tensorflow的新Dataset API。数据集API是否允许我指定要保留在内存中的记录数,或者Tensorflow是否自动管理内存以便我不必?

3 个答案:

答案 0 :(得分:3)

是。官方指南中的示例(使用TensorFlow输入管道的数据集API,https://www.tensorflow.org/programmers_guide/datasets

filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.contrib.data.TFRecordDataset(filenames)
dataset = dataset.map(...) ## Parsing data with a user specified function
dataset = dataset.shuffle(buffer_size=10000) ## 10000: size of sample/record pool for random selection
dataset = dataset.batch(32) ## 32: number of samples/records per batch (to be read into memory)
dataset = dataset.repeat() ## None: keep repeating

答案 1 :(得分:1)

如果您要通过batch_size指定记录数。在这种情况下,TF将仅从文件中获取batch_size元素。您还可以指定shuffle,这样可以保证内存中的所有时间都处于最大buffer_size个元素。

我在我的tfrecords文件上验证了它。我有100个tfrecords文件,每个文件大约10Gb(这比我的笔记本电脑上的内存更多)。一切正常。

答案 2 :(得分:0)

dataset = dataset.prefetch(buffer_size)我想prefetch会这样做吗?如果buffer_size设置得足够大,那么所有tfrcords都会在mem中保存? Meaning of buffer_size in Dataset.map , Dataset.prefetch and Dataset.shuffle