更改batch(),shuffle()和repeat()的顺序时的输出差异

时间:2018-04-19 08:10:21

标签: tensorflow tensorflow-datasets

我已经创建了张量流数据集,使其可重复,将其重新排列,将其分成批次,并构建了一个迭代器来获取下一批。但是当我这样做时,有时候元素是重复的(在批次内和批次之间),特别是对于小数据集。为什么呢?

3 个答案:

答案 0 :(得分:5)

与您自己的回答中所说的不同,不,改组然后重复不会解决您的问题

问题的关键来源批处理,然后是随机播放/重复。这样,批次中的项目将始终从输入数据集中的连续样本中获取。 批处理应该是您在输入管道中执行的最后一项操作

稍微扩展问题。

现在,在您改组,重复和批处理的顺序上的差异,但这不是您的想法。引自input pipeline performance guide

  

如果在随机播放之前应用重复转换   转换,然后时代边界模糊。那是,   某些元素可以在其他元素出现之前重复出现   一旦。另一方面,如果应用了混洗变换   在重复转换之前,性能可能会减慢   每个纪元的开始与内部的初始化有关   洗牌转型的状态。换句话说,前者   (在shuffle之前重复)提供更好的性能,而后者   (重复之前的洗牌)提供更强的排序保证。

重新盖上

  • 重复,然后洗牌:你失去了在一个时期内处理所有样品的保证。
  • 随机播放,然后重复:保证在下一次重复开始之前处理所有样本,但是你的性能会稍微下降。

无论您选择哪种,在批处理之前执行

答案 1 :(得分:2)

你必须先洗牌,然后重复!

如以下两个代码所示,洗牌和重复的顺序。

最差的订购:

import tensorflow as tf

ds = tf.data.Dataset.range(10)
ds = ds.batch(2)
ds = ds.repeat()
ds = ds.shuffle(100000)
iterator   = ds.make_one_shot_iterator()
next_batch = iterator.get_next()

with tf.Session() as sess:
    for i in range(15):
        if i % (10//2) == 0:
            print("------------")
        print("{:02d}:".format(i), next_batch.eval())

输出:

------------
00: [6 7]
01: [2 3]
02: [6 7]
03: [0 1]
04: [8 9]
------------
05: [6 7]
06: [4 5]
07: [6 7]
08: [4 5]
09: [0 1]
------------
10: [2 3]
11: [0 1]
12: [0 1]
13: [2 3]
14: [4 5]

错误订购:

import tensorflow as tf

ds = tf.data.Dataset.range(10)
ds = ds.batch(2)
ds = ds.shuffle(100000)
ds = ds.repeat()
iterator   = ds.make_one_shot_iterator()
next_batch = iterator.get_next()

with tf.Session() as sess:
    for i in range(15):
        if i % (10//2) == 0:
            print("------------")
        print("{:02d}:".format(i), next_batch.eval())

输出:

------------
00: [4 5]
01: [6 7]
02: [8 9]
03: [0 1]
04: [2 3]
------------
05: [0 1]
06: [4 5]
07: [8 9]
08: [2 3]
09: [6 7]
------------
10: [0 1]
11: [4 5]
12: [8 9]
13: [2 3]
14: [6 7]

最佳订购:

受到GPhilo答案的启发,批处理的顺序也很重要。对于每个时期的批次不同,必须首先进行洗牌,然后重复,最后批量进行。从输出中可以看出,所有批次都是独一无二的,与其他批次不同。

import tensorflow as tf

ds = tf.data.Dataset.range(10)

ds = ds.shuffle(100000)
ds = ds.repeat()
ds = ds.batch(2)

iterator   = ds.make_one_shot_iterator()
next_batch = iterator.get_next()

with tf.Session() as sess:
    for i in range(15):
        if i % (10//2) == 0:
            print("------------")
        print("{:02d}:".format(i), next_batch.eval())

输出:

------------
00: [2 5]
01: [1 8]
02: [9 6]
03: [3 4]
04: [7 0]
------------
05: [4 3]
06: [0 2]
07: [1 9]
08: [6 5]
09: [8 7]
------------
10: [7 3]
11: [5 9]
12: [4 1]
13: [8 6]
14: [0 2]

答案 2 :(得分:1)

例如,如果您想要与Keras的.fit()函数相同的行为,则可以使用:

dataset = dataset.shuffle(10000, reshuffle_each_iteration=True)
dataset = dataset.batch(BATCH_SIZE)
dataset = dataset.repeat(EPOCHS)

这将以与.fit(epochs=EPOCHS, batch_size=BATCH_SIZE, shuffle=True)相同的方式遍历数据集。一个简单的示例(渴望执行仅出于可读性,在图形模式下的行为相同):

import numpy as np
import tensorflow as tf
tf.enable_eager_execution()

NUM_SAMPLES = 7
BATCH_SIZE = 3
EPOCHS = 2

# Create the dataset
x = np.array([[2 * i, 2 * i + 1] for i in range(NUM_SAMPLES)])
dataset = tf.data.Dataset.from_tensor_slices(x)

# Shuffle, batch and repeat the dataset
dataset = dataset.shuffle(10000, reshuffle_each_iteration=True)
dataset = dataset.batch(BATCH_SIZE)
dataset = dataset.repeat(EPOCHS)

# Iterate through the dataset
iterator = dataset.make_one_shot_iterator()
for batch in dataset:
    print(batch.numpy(), end='\n\n')

打印

[[ 8  9]
 [12 13]
 [10 11]]

[[0 1]
 [2 3]
 [4 5]]

[[6 7]]

[[ 4  5]
 [10 11]
 [12 13]]

[[6 7]
 [0 1]
 [2 3]]

[[8 9]]

您可以看到,即使.batch()之后 .shuffle()被称为,批次在每个时期仍然是不同的。这就是为什么我们需要使用reshuffle_each_iteration=True。如果我们不希望在每次迭代中都进行改组,那么在每个时期我们将获得相同的批次:

[[12 13]
 [ 4  5]
 [10 11]]

[[6 7]
 [8 9]
 [0 1]]

[[2 3]]

[[12 13]
 [ 4  5]
 [10 11]]

[[6 7]
 [8 9]
 [0 1]]

[[2 3]]

这在训练小型数据集时可能是有害的。