如何仅从tfrecord文件加载数据一次

时间:2018-08-27 06:49:52

标签: python tensorflow tensorflow-datasets

我正在玩tf.data.Dataset和tfrecord文件来找出我的张量流代码中的瓶颈。在下面的玩具脚本中,在第二个时期中,数据被删除并且无法从磁盘读取。

import os
import tensorflow as tf
import numpy as np
tf.enable_eager_execution()

xs = [[1,2,3], [2,4,45,5],[5,5,56,2]]
names = ['a','b','c']
os.mkdir('foo')
for name, x in zip(names, xs):
    features = tf.train.Features(feature={
        'x': tf.train.Feature(int64_list=tf.train.Int64List(value=x))})
    ex = tf.train.Example(features=features)
    writer = tf.python_io.TFRecordWriter(f'foo/{name}')
    writer.write(ex.SerializeToString())
    writer.close()

def _parse_function(example_proto):
    features={'x':tf.VarLenFeature(tf.int64),
            }
    parsed_features = tf.parse_single_example(example_proto, features)
    return parsed_features

dataset = tf.data.TFRecordDataset(['foo/'+ f for f in os.listdir('foo')])
dataset.shuffle(buffer_size=3)
dataset.map(_parse_function)
dataset.cache()
dataset.repeat(3)
for i in range(3):
    if i == 1:
        for f in os.listdir('foo'):
            os.remove('foo/'+f)
    for d in dataset:
        print(d)

在第二个时期,我得到了这个错误:

tensorflow.python.framework.errors_impl.NotFoundError:
foo/c; No such file or directory [Op:IteratorGetNextSync].

这意味着tf.data.TFRecordDataset()甚至在每个tf.data.TFRecordDataset.cache()时都从磁盘读取数据。

我以为整个数据集都将缓存在内存中,并且文件io可能不会在第一个时期之后发生。显然不是这种情况。 cache()方法有什么作用?另外,有什么方法在训练开始时只加载一次整个tfrecord文件吗?

0 个答案:

没有答案