如何正确使用Tensorflow数据集进行批处理?

时间:2018-10-19 10:52:36

标签: python tensorflow tensorflow-datasets

我是Tensorflow和深度学习的新手,并且在Dataset类中挣扎。我尝试了很多事情,但找不到一个好的解决方案。

我在想什么

我有大量图像(500k +)用于训练DNN。这是一个降噪自动编码器,因此每个图像都有一对。我正在使用TF的数据集类来管理数据,但我认为我确实使用得很差。

这是我如何在数据集中加载文件名:

class Data:
def __init__(self, in_path, out_path):
    self.nb_images = 512
    self.test_ratio = 0.2
    self.batch_size = 8

    # load filenames in input and outputs
    inputs, outputs, self.nb_images = self._load_data_pair_paths(in_path, out_path, self.nb_images)

    self.size_training = self.nb_images - int(self.nb_images * self.test_ratio)
    self.size_test = int(self.nb_images * self.test_ratio)

    # split arrays in training / validation
    test_data_in, training_data_in = self._split_test_data(inputs, self.test_ratio)
    test_data_out, training_data_out = self._split_test_data(outputs, self.test_ratio)

    # transform array to tf.data.Dataset
    self.train_dataset = tf.data.Dataset.from_tensor_slices((training_data_in, training_data_out))
    self.test_dataset = tf.data.Dataset.from_tensor_slices((test_data_in, test_data_out))

我有一个函数在准备数据集的每个纪元处调用。它会改组文件名,并将文件名转换为图像和批处理数据。

def get_batched_data(self, seed, batch_size):
    nb_batch = int(self.size_training / batch_size)

    def img_to_tensor(path_in, path_out):
        img_string_in = tf.read_file(path_in)
        img_string_out = tf.read_file(path_out)
        im_in = tf.image.decode_jpeg(img_string_in, channels=1)
        im_out = tf.image.decode_jpeg(img_string_out, channels=1)
        return im_in, im_out

    t_datas = self.train_dataset.shuffle(self.size_training, seed=seed)
    t_datas = t_datas.map(img_to_tensor)
    t_datas = t_datas.batch(batch_size)
    return t_datas

现在,在训练期间,我们在每个时期调用get_batched_data函数,进行迭代,然后针对每个批次运行该迭代器,然后将数组提供给优化器操作。

for epoch in range(nb_epoch):
    sess_iter_in = tf.Session()
    sess_iter_out = tf.Session()

    batched_train = data.get_batched_data(epoch)
    iterator_train = batched_train.make_one_shot_iterator()
    in_data, out_data = iterator_train.get_next()

    total_batch = int(data.size_training / batch_size)
    for batch in range(total_batch):
        print(f"{batch + 1} / {total_batch}")
        in_images = sess_iter_in.run(in_data).reshape((-1, 64, 64, 1))
        out_images = sess_iter_out.run(out_data).reshape((-1, 64, 64, 1))
        sess.run(optimizer, feed_dict={inputs: in_images,
                                       outputs: out_images})

我需要什么?

我需要有一个仅加载当前批次图像的管道(否则它将不适合存储在内存中),并且我想针对每个时期以不同的方式重新整理数据集。

疑问和问题

第一个问题,我是否很好地使用了Dataset类?我在互联网上看到了截然不同的事物,例如在this博客文章中,数据集与占位符一起使用,并在学习过程中与数据一起馈入。似乎很奇怪,因为数据全部在一个数组中,因此被加载到内存中。我看不出在这种情况下使用tf.data.dataset的意义。

我通过在数据集上使用repeat(epoch)(例如this)找到了解决方案,但是在这种情况下,每个时期的混洗都不会有所不同。

实现的第二个问题是在某些情况下我有一个OutOfRangeError。使用少量数据(例如512),可以正常工作,但是使用大量数据时,会发生错误。我认为这是由于舍入不佳导致批处理数量计算不正确,或者最后一个批处理的数据量较小时发生的,但是它发生在115个批处理中的第32个...有什么办法可以知道在batch(n)调用数据集之后创建的批处理数量是多少?

很抱歉遇到这个问题,但是我已经为此苦苦挣扎了几天。

1 个答案:

答案 0 :(得分:2)

据我所知,Official Performance Guideline是制作输入管道的最佳教材。

  

我想针对每个时期以不同的方式对数据集进行洗牌。

使用shuffle()和repeat(),您可以为每个时期获得不同的随机播放模式。您可以使用以下代码进行确认

dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4])
dataset = dataset.shuffle(4)
dataset = dataset.repeat(3)

iterator = dataset.make_one_shot_iterator()
x = iterator.get_next()

with tf.Session() as sess:
    for i in range(10):
        print(sess.run(x))

您也可以使用上面官方页面中提到的tf.contrib.data.shuffle_and_repeat。

除了创建数据管道之外,您的代码中还有一些问题。您将图的构造与图的执行混淆了。您将重复创建数据输入管道,因此有许多冗余输入管道,其数量与时期一样多。您可以通过Tensorboard观察冗余管道。

您应该将图形构造代码放在循环之外,作为以下代码(伪代码)

batched_train = data.get_batched_data()
iterator = batched_train.make_initializable_iterator()
in_data, out_data = iterator_train.get_next()

for epoch in range(nb_epoch):
    # reset iterator's state
    sess.run(iterator.initializer)

    try:
        while True:
            in_images = sess.run(in_data).reshape((-1, 64, 64, 1))
            out_images = sess.run(out_data).reshape((-1, 64, 64, 1))
            sess.run(optimizer, feed_dict={inputs: in_images,
                                           outputs: out_images})
    except tf.errors.OutOfRangeError:
        pass

此外,还有一些不重要的低效率代码。您使用from_tensor_slices()加载了文件路径列表,因此该列表已嵌入到图形中。 (有关详细信息,请参见https://www.tensorflow.org/guide/datasets#consuming_numpy_arrays

最好使用预取,并通过组合图形来减少sess.run调用。