我的训练数据集太大而无法容纳到内存中,因此我的代码一次只能从磁盘读取1,000条记录。现在我想使用Tensorflow的新Dataset API。数据集API是否允许我指定要保留在内存中的记录数,或者Tensorflow是否自动管理内存以便我不必?
答案 0 :(得分:3)
是。官方指南中的示例(使用TensorFlow输入管道的数据集API,https://www.tensorflow.org/programmers_guide/datasets)
filenames = ["/var/data/file1.tfrecord", "/var/data/file2.tfrecord"]
dataset = tf.contrib.data.TFRecordDataset(filenames)
dataset = dataset.map(...) ## Parsing data with a user specified function
dataset = dataset.shuffle(buffer_size=10000) ## 10000: size of sample/record pool for random selection
dataset = dataset.batch(32) ## 32: number of samples/records per batch (to be read into memory)
dataset = dataset.repeat() ## None: keep repeating
答案 1 :(得分:1)
如果您要通过batch_size指定记录数。在这种情况下,TF将仅从文件中获取batch_size元素。您还可以指定shuffle,这样可以保证内存中的所有时间都处于最大buffer_size
个元素。
我在我的tfrecords文件上验证了它。我有100个tfrecords文件,每个文件大约10Gb(这比我的笔记本电脑上的内存更多)。一切正常。
答案 2 :(得分:0)
dataset = dataset.prefetch(buffer_size)我想prefetch会这样做吗?如果buffer_size设置得足够大,那么所有tfrcords都会在mem中保存? Meaning of buffer_size in Dataset.map , Dataset.prefetch and Dataset.shuffle