批量将数据添加到tf.data.Dataset API

时间:2019-08-01 18:25:00

标签: tensorflow tensorflow-datasets

我有一个数据集,我想传递给tf.data.Dataset API进行调度。但是,我需要读入大量数据文件并对其进行预处理,然后再传递给tf.data.Dataset。我不想一次将所有数据读入内存。目前,我一次读取一个文件,对它们进行批处理,然后读取下一个文件。即

import itertools

class DataBatcher:
    def __init__(self, list_of_files):
        self.list_of_files = list_of_files
        self.init_index = 0
        self.x, self.y =  self.loop_files()

    def loop_files(self):
        for one_part in itertools.cycle(self.list_of_files):
             # read in one_part file, which has N examples
             # doing some prepocessing
             # self.x and self.y will be of shape [N, feature_num], [N,] respectively
             yield self.x, self.y

    def gen_batch(self, batch_size):
        # discard remaining data if it't not enough for a batch
        if (self.init_index + batch_size) > self.y.shape[0]:
            self.x, self.y =  self.loop_files()
            self.init_index = 0
        x_batch, y_batch = self.x[self.init_index: self.init_index + batch_size, :], self.y[self.init_index: self.init_index + batch_size]
        self.init_index += batch_size
        return x_batch, y_batch 

data_batcher = DataBatcher(files)
x_sample, y_sample = data_batcher.gen_batch(10)

然后将

x_sampley_sample通过feed_dict传递到张量流模型。但是我想更改为tf.data.Dataset以提高性能。我发现tf.data.Dataset.from_tensor_slices会在分派之前读取整个数据集,而tf.data.Dataset.from_generator会在一个示例中同时读取一个示例,这两个示例都没有高效地使用资源,两者之间有什么办法吗?类似于使用生成器生成数据集的一部分?

1 个答案:

答案 0 :(得分:0)

如果使用数据集API,则不需要您的批处理逻辑。最好是:

  1. 在现有的生成器函数上使用Dataset.from_generator
  2. 然后使用dataset.batch()方法将其批处理为所需的任何大小的批处理。

也许是这样的:

data_batcher = DataBatcher(files)

dataset = tf.data.Dataset.from_generator(data_batcher.loop_files, output_types=(tf.float32, tf.float32))
batched_data = dataset.batch(BATCH_SIZE)