张量流数据集随后批处理或批处理随后随机播放

时间:2018-05-20 16:52:05

标签: tensorflow tensorflow-datasets

我最近开始学习张量流。

我不确定是否存在差异

x = np.array([[1],[2],[3],[4],[5]])
dataset = tf.data.Dataset.from_tensor_slices(x)
ds.shuffle(buffer_size=4)
ds.batch(4)

x = np.array([[1],[2],[3],[4],[5]])
dataset = tf.data.Dataset.from_tensor_slices(x)
ds.batch(4)
ds.shuffle(buffer_size=4)

另外,我不知道为什么我不能使用

dataset = dataset.shuffle_batch(buffer_size=2,batch_size=BATCH_SIZE)

因为它给出了错误

dataset = dataset.shuffle_batch(buffer_size=2,batch_size=BATCH_SIZE)
AttributeError: 'TensorSliceDataset' object has no attribute 'shuffle_batch'

谢谢!

2 个答案:

答案 0 :(得分:3)

TL; DR:是的,有区别。几乎总是,您需要在 Dataset.shuffle()之前致电Dataset.batch() tf.data.Dataset类上没有for (const elem of someArray) 方法,您必须分别调用这两个方法来对数据集进行随机播放和批处理。

shuffle_batch()的转换应用于它们被调用的相同序列。 tf.data.Dataset将其输入的连续元素组合到输出中的单个批处理元素中。 通过考虑以下两个数据集,我们可以看到操作顺序的影响:

Dataset.batch()

在第一个版本(shuffle之前的批处理)中,每个批次的元素是来自输入的3个连续元素;而在第二个版本(批次之前随机播放)中,它们是从输入中随机抽样的。通常,当通过(某些变体)小批量stochastic gradient descent进行培训时,应从总输入中尽可能均匀地对每批次的元素进行采样。否则,网络可能会过度适应输入数据中的任何结构,并且生成的网络将无法实现高精度。

答案 1 :(得分:1)

完全同意@mrry,但是在一种情况下,您可能想在混洗前 进行批处理。假设您正在处理一些文本数据,这些数据将被馈送到RNN中。在这里,每个句子被视为一个序列,而一批将包含多个序列。由于句子的长度是可变的,因此我们需要批量将句子填充到统一的长度。一种有效的方法是通过批处理将相似长度的句子分组在一起,然后进行混排。否则,我们可能会结束充满<pad>令牌的批次。