我目前正在学习TensorFlow,但是我在这段代码中遇到了困惑:
dataset = dataset.shuffle(buffer_size = 10 * batch_size)
dataset = dataset.repeat(num_epochs).batch(batch_size)
return dataset.make_one_shot_iterator().get_next()
我首先知道数据集将保存所有数据,但是shuffle(),repeat()和batch()对数据集有什么作用?请给我一个例子的解释
答案 0 :(得分:6)
想象一下,您有一个数据集:[1, 2, 3, 4, 5, 6]
,然后:
ds.shuffle()的工作方式
dataset.shuffle(buffer_size=3)
将分配一个大小为3的缓冲区以挑选随机条目。该缓冲区将连接到源数据集。
我们可以这样成像:
Random buffer
|
| Source dataset where all other elements live
| |
↓ ↓
[1,2,3] <= [4,5,6]
让我们假设条目2
是从随机缓冲区中提取的。可用空间由源缓冲区中的下一个元素填充,即4
:
2 <= [1,3,4] <= [5,6]
我们继续阅读,直到什么都没剩下:
1 <= [3,4,5] <= [6]
5 <= [3,4,6] <= []
3 <= [4,6] <= []
6 <= [4] <= []
4 <= [] <= []
ds.repeat()的工作方式
一旦从数据集中读取了所有条目,并且您尝试读取下一个元素,则数据集将引发错误。
这就是ds.repeat()
发挥作用的地方。它将重新初始化数据集,使其再次像这样:
[1,2,3] <= [4,5,6]
ds.batch()将产生什么
ds.batch()
将首先获取batch_size
条目,并从中进行批量处理。因此,示例数据集的批处理大小为3将产生两个批处理记录:
[2,1,5]
[3,6,4]
由于在批处理之前有一个ds.repeat()
,所以数据的生成将继续。但是,由于ds.random()
,元素的顺序将有所不同。应该考虑的是,由于随机缓冲区的大小,6
永远不会出现在第一批中。
答案 1 :(得分:0)
tf.Dataset中的以下方法:
repeat( count=0 )
该方法重复数据集count
次数。shuffle( buffer_size, seed=None, reshuffle_each_iteration=None)
该方法可对数据集中的样本进行混洗。 buffer_size
是被随机化并返回为tf.Dataset
的样本数。batch(batch_size,drop_remainder=False)
创建数据集的批次,其批次大小指定为batch_size
,这也是批次的长度。答案 2 :(得分:0)
一个显示历元循环的示例。运行此脚本后,请注意
dataset_gen1
-随机操作会产生更多随机输出(这在运行机器学习实验时可能会更有用)dataset_gen2
-缺少随机操作会按顺序生成元素此脚本中的其他添加内容
tf.data.experimental.sample_from_datasets
-用于合并两个数据集。请注意,这种情况下的随机播放操作将创建一个缓冲区,该缓冲区从两个数据集中均等地采样。import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" # to avoid all those prints
os.environ["TF_GPU_THREAD_MODE"] = "gpu_private" # to avoid large "Kernel Launch Time"
import tensorflow as tf
if len(tf.config.list_physical_devices('GPU')):
tf.config.experimental.set_memory_growth(tf.config.list_physical_devices('GPU')[0], True)
class Augmentations:
def __init__(self):
pass
@tf.function
def filter_even(self, x):
if x % 2 == 0:
return False
else:
return True
class Dataset:
def __init__(self, aug, range_min=0, range_max=100):
self.range_min = range_min
self.range_max = range_max
self.aug = aug
def generator(self):
dataset = tf.data.Dataset.from_generator(self._generator
, output_types=(tf.float32), args=())
dataset = dataset.filter(self.aug.filter_even)
return dataset
def _generator(self):
for item in range(self.range_min, self.range_max):
yield(item)
# Can be used when you have multiple datasets that you wish to combine
class ZipDataset:
def __init__(self, datasets):
self.datasets = datasets
self.datasets_generators = []
def generator(self):
for dataset in self.datasets:
self.datasets_generators.append(dataset.generator())
return tf.data.experimental.sample_from_datasets(self.datasets_generators)
if __name__ == "__main__":
aug = Augmentations()
dataset1 = Dataset(aug, 0, 100)
dataset2 = Dataset(aug, 100, 200)
dataset = ZipDataset([dataset1, dataset2])
epochs = 2
shuffle_buffer = 10
batch_size = 4
prefetch_buffer = 5
dataset_gen1 = dataset.generator().shuffle(shuffle_buffer).batch(batch_size).prefetch(prefetch_buffer)
# dataset_gen2 = dataset.generator().batch(batch_size).prefetch(prefetch_buffer) # this will output odd elements in sequence
for epoch in range(epochs):
print ('\n ------------------ Epoch: {} ------------------'.format(epoch))
for X in dataset_gen1.repeat(1): # adding .repeat() in the loop allows you to easily control the end of the loop
print (X)
# Do some stuff at end of loop