如何在第一个时期正确缓存数据(Tensorflow,数据集)?

时间:2018-05-24 23:03:41

标签: tensorflow tensorflow-datasets

我尝试对cache使用dataset转换。这是我目前的代码(简化):

dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=1)
dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=5000, count=1))
dataset = dataset.map(_parser_a, num_parallel_calls=12)
dataset = dataset.padded_batch(
    20, 
    padded_shapes=padded_shapes,
    padding_values=padding_values
)
dataset = dataset.prefetch(buffer_size=1)
dataset = dataset.cache()

在第一个纪元之后,我收到以下错误消息:

  

调用迭代器没有完全读取我们正在尝试的数据集   缓存。为了避免意外截断序列,   当前[部分缓存]序列将被删除。如果这可能发生   你有一个类似于dataset.cache().take(k).repeat()的序列。   相反,交换订单(即dataset.take(k).cache().repeat()

然后,代码继续并仍然从硬盘驱动器而不是缓存中读取数据。那么,我应该在哪里放置dataset.cache()以避免错误? 感谢。

1 个答案:

答案 0 :(得分:3)

Dataset.cache()转换的实现非常简单:它会在第一次迭代完全时构建一个通过它的元素列表,并返回元素从该列表中随后尝试迭代它。如果第一遍仅对数据执行部分传递,则列表不完整,并且TensorFlow不会尝试使用缓存数据,因为它不知道是否需要其余元素,通常它可能需要重新处理所有前面的元素来计算剩余的元素。

通过修改程序以使用整个数据集,并迭代它直到引发tf.errors.OutOfRangeError,缓存将具有数据集中元素的完整列表,并将在所有后续迭代中使用。