使用TensorFlow数据集进行批处理,重复和随机播放有什么作用?

时间:2018-11-28 07:47:03

标签: tensorflow dataset

我目前正在学习TensorFlow,但是我在这段代码中遇到了困惑:

dataset = dataset.shuffle(buffer_size = 10 * batch_size) 
dataset = dataset.repeat(num_epochs).batch(batch_size)
return dataset.make_one_shot_iterator().get_next()

我首先知道数据集将保存所有数据,但是shuffle(),repeat()和batch()对数据集有什么作用?请给我一个例子的解释

3 个答案:

答案 0 :(得分:6)

想象一下,您有一个数据集:[1, 2, 3, 4, 5, 6],然后:

ds.shuffle()的工作方式

dataset.shuffle(buffer_size=3)将分配一个大小为3的缓冲区以挑选随机条目。该缓冲区将连接到源数据集。 我们可以这样成像:

Random buffer
   |
   |   Source dataset where all other elements live
   |         |
   ↓         ↓
[1,2,3] <= [4,5,6]

让我们假设条目2是从随机缓冲区中提取的。可用空间由源缓冲区中的下一个元素填充,即4

2 <= [1,3,4] <= [5,6]

我们继续阅读,直到什么都没剩下:

1 <= [3,4,5] <= [6]
5 <= [3,4,6] <= []
3 <= [4,6]   <= []
6 <= [4]      <= []
4 <= []      <= []

ds.repeat()的工作方式

一旦从数据集中读取了所有条目,并且您尝试读取下一个元素,则数据集将引发错误。 这就是ds.repeat()发挥作用的地方。它将重新初始化数据集,使其再次像这样:

[1,2,3] <= [4,5,6]

ds.batch()将产生什么

ds.batch()将首先获取batch_size条目,并从中进行批量处理。因此,示例数据集的批处理大小为3将产生两个批处理记录:

[2,1,5]
[3,6,4]

由于在批处理之前有一个ds.repeat(),所以数据的生成将继续。但是,由于ds.random(),元素的顺序将有所不同。应该考虑的是,由于随机缓冲区的大小,6永远不会出现在第一批中。

答案 1 :(得分:0)

tf.Dataset中的以下方法:

  1. repeat( count=0 )该方法重复数据集count次数。
  2. shuffle( buffer_size, seed=None, reshuffle_each_iteration=None)该方法可对数据集中的样本进行混洗。 buffer_size是被随机化并返回为tf.Dataset的样本数。
  3. batch(batch_size,drop_remainder=False)创建数据集的批次,其批次大小指定为batch_size,这也是批次的长度。

答案 2 :(得分:0)

一个显示历元循环的示例。运行此脚本后,请注意

  • dataset_gen1-随机操作会产生更多随机输出(这在运行机器学习实验时可能会更有用
  • dataset_gen2-缺少随机操作会按顺序生成元素

此脚本中的其他添加内容

  • tf.data.experimental.sample_from_datasets-用于合并两个数据集。请注意,这种情况下的随机播放操作将创建一个缓冲区,该缓冲区从两个数据集中均等地采样。
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # to avoid all those prints
os.environ["TF_GPU_THREAD_MODE"] = "gpu_private" # to avoid large "Kernel Launch Time"

import tensorflow as tf
if len(tf.config.list_physical_devices('GPU')):
    tf.config.experimental.set_memory_growth(tf.config.list_physical_devices('GPU')[0], True)

class Augmentations:

    def __init__(self):
        pass

    @tf.function
    def filter_even(self, x):
        if x % 2 == 0:
            return False
        else:
            return True

class Dataset:

    def __init__(self, aug, range_min=0, range_max=100):
        self.range_min = range_min
        self.range_max = range_max
        self.aug = aug

    def generator(self):
        dataset = tf.data.Dataset.from_generator(self._generator
                        , output_types=(tf.float32), args=())

        dataset = dataset.filter(self.aug.filter_even)

        return dataset
    
    def _generator(self):
        for item in range(self.range_min, self.range_max):
            yield(item)

# Can be used when you have multiple datasets that you wish to combine
class ZipDataset:

    def __init__(self, datasets):
        self.datasets = datasets
        self.datasets_generators = []
    
    def generator(self):
        for dataset in self.datasets:
            self.datasets_generators.append(dataset.generator())
        return tf.data.experimental.sample_from_datasets(self.datasets_generators)

if __name__ == "__main__":
    aug = Augmentations()
    dataset1 = Dataset(aug, 0, 100)
    dataset2 = Dataset(aug, 100, 200)
    dataset = ZipDataset([dataset1, dataset2])

    epochs = 2
    shuffle_buffer = 10
    batch_size = 4
    prefetch_buffer = 5

    dataset_gen1 = dataset.generator().shuffle(shuffle_buffer).batch(batch_size).prefetch(prefetch_buffer)
    # dataset_gen2 = dataset.generator().batch(batch_size).prefetch(prefetch_buffer) # this will output odd elements in sequence 

    for epoch in range(epochs):
        print ('\n ------------------ Epoch: {} ------------------'.format(epoch))
        for X in dataset_gen1.repeat(1): # adding .repeat() in the loop allows you to easily control the end of the loop
            print (X)
        
        # Do some stuff at end of loop