我正在玩tf.data.Dataset
和tfrecord文件来找出我的张量流代码中的瓶颈。在下面的玩具脚本中,在第二个时期中,数据被删除并且无法从磁盘读取。
import os
import tensorflow as tf
import numpy as np
tf.enable_eager_execution()
xs = [[1,2,3], [2,4,45,5],[5,5,56,2]]
names = ['a','b','c']
os.mkdir('foo')
for name, x in zip(names, xs):
features = tf.train.Features(feature={
'x': tf.train.Feature(int64_list=tf.train.Int64List(value=x))})
ex = tf.train.Example(features=features)
writer = tf.python_io.TFRecordWriter(f'foo/{name}')
writer.write(ex.SerializeToString())
writer.close()
def _parse_function(example_proto):
features={'x':tf.VarLenFeature(tf.int64),
}
parsed_features = tf.parse_single_example(example_proto, features)
return parsed_features
dataset = tf.data.TFRecordDataset(['foo/'+ f for f in os.listdir('foo')])
dataset.shuffle(buffer_size=3)
dataset.map(_parse_function)
dataset.cache()
dataset.repeat(3)
for i in range(3):
if i == 1:
for f in os.listdir('foo'):
os.remove('foo/'+f)
for d in dataset:
print(d)
在第二个时期,我得到了这个错误:
tensorflow.python.framework.errors_impl.NotFoundError:
foo/c; No such file or directory [Op:IteratorGetNextSync].
这意味着tf.data.TFRecordDataset()
甚至在每个tf.data.TFRecordDataset.cache()
时都从磁盘读取数据。
我以为整个数据集都将缓存在内存中,并且文件io可能不会在第一个时期之后发生。显然不是这种情况。 cache()
方法有什么作用?另外,有什么方法在训练开始时只加载一次整个tfrecord文件吗?