对整个数据集或每次调用iterator.next()进行一次Tensorflow数据集数据预处理?

时间:2018-02-11 14:28:04

标签: python tensorflow tensorflow-datasets

您好我正在研究tensorflow中的数据集API,我对datt.map()函数有一个问题,该函数执行数据预处理。

file_name = ["image1.jpg", "image2.jpg", ......]
im_dataset = tf.data.Dataset.from_tensor_slices(file_names)
im_dataset = im_dataset.map(lambda image:tuple(tf.py_func(image_parser(), [image], [tf.float32, tf.float32, tf.float32])))
im_dataset = im_dataset.batch(batch_size)
iterator = im_dataset.make_initializable_iterator()

数据集接收图像名称并将其解析为3个张量(关于图像的3个信息)。

如果我的训练文件夹中有大量图像,预处理它们需要很长时间。 我的问题是,由于数据集API据说是为高效的输入管道而设计的,因此在我将它们提供给我的工作人员(比如说GPU)之前对整个数据集进行预处理,或者每次我只预处理一批图像调用iterator.get_next()?

1 个答案:

答案 0 :(得分:7)

如果预处理管道很长且输出很小,则处理后的数据应该适合内存。如果是这种情况,您可以使用tf.data.Dataset.cache将已处理的数据缓存在内存或文件中。

来自官方performance guide

  

tf.data.Dataset.cache转换可以在内存或本地存储中缓存数据集。如果传递给映射转换的用户定义函数很昂贵,只要结果数据集仍然适合内存或本地存储,就可以在映射转换后应用缓存转换。如果用户定义的函数增加了存储数据集超出缓存容量所需的空间,请考虑在训练作业之前预处理数据以减少资源使用。

在内存中使用缓存的示例

以下是每个预处理需要花费大量时间(0.5秒)的示例。数据集上的第二个时期将比第一个时期快得多

def my_fn(x):
    time.sleep(0.5)
    return x

def parse_fn(x):
    return tf.py_func(my_fn, [x], tf.int64)

dataset = tf.data.Dataset.range(5)
dataset = dataset.map(parse_fn)
dataset = dataset.cache()    # cache the processed dataset, so every input will be processed once
dataset = dataset.repeat(2)  # repeat for multiple epochs

res = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    for i in range(10):
        # First 5 iterations will take 0.5s each, last 5 will not
        print(sess.run(res))

缓存到文件

如果要将缓存数据写入文件,可以为cache()提供参数:

dataset = dataset.cache('/tmp/cache')  # will write cached data to a file

这将允许您只处理数据集一次,并对数据运行多个实验,而无需再次重新处理。

警告:缓存到文件时,您必须小心。如果您更改数据但保留/tmp/cache.*文件,它仍会读取缓存的旧数据。例如,如果我们使用上面的数据并将数据范围更改为[10, 15],我们仍会在[0, 5]中获取数据:

dataset = tf.data.Dataset.range(10, 15)
dataset = dataset.map(parse_fn)
dataset = dataset.cache('/tmp/cache')
dataset = dataset.repeat(2)  # repeat for multiple epochs

res = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    for i in range(10):
        print(sess.run(res))  # will still be in [0, 5]...
  每当您要缓存的数据发生变化时,

始终删除缓存的文件。

可能出现的另一个问题是,如果在缓存所有数据之前中断脚本。您将收到如下错误:

  

AlreadyExistsError(参见上面的回溯):似乎有一个并发缓存迭代器正在运行 - 缓存锁定文件已经存在(' /tmp/cache.lockfile')。如果您确定没有其他正在运行的TF计算正在使用此缓存前缀,请删除锁定文件并重新初始化迭代器。

确保您处理整个数据集以拥有整个缓存文件。