为什么在同一个 tf.data.Dataset 上的迭代每次迭代都会给出不同的数据?

时间:2021-03-30 16:35:49

标签: python tensorflow tensorflow-datasets

我正在尝试了解 tf.data.Dataset 的工作原理。

它在文档上说 take 返回一个数据集,其中包含来自该数据集的一定数量的元素。然后您可以迭代单个样本(在本例中为批次):

import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds

# Construct a tf.data.Dataset
ds = tfds.load('mnist', split='train', shuffle_files=True)

# Build your input pipeline
ds = ds.shuffle(1024).batch(32).prefetch(tf.data.experimental.AUTOTUNE)

single_batch_dataset = ds.take(1)

for example in single_batch_dataset:
  image, label = example["image"], example["label"]
  print(label)
# ...

输出:

tf.Tensor([2 0 6 6 8 8 6 0 3 4 8 7 5 2 5 7 8 7 1 1 1 8 6 4 0 4 3 2 4 2 1 9], shape=(32,), dtype=int64)

然而,再次迭代,给出不同的标签:(上次代码的延续)

for example in single_batch_dataset:
  image, label = example["image"], example["label"]
  print(label)

for example in single_batch_dataset:
  image, label = example["image"], example["label"]
  print(label)

输出:

tf.Tensor([7 3 5 6 3 1 7 9 6 1 9 3 9 8 6 7 7 1 9 7 5 2 0 7 8 1 7 8 7 0 5 0], shape=(32,), dtype=int64)
tf.Tensor([1 3 6 1 8 8 0 4 1 3 2 9 5 3 8 7 4 2 1 8 1 0 8 5 4 5 6 7 3 4 4 1], shape=(32,), dtype=int64)

鉴于数据集相同,标签不应该相同吗?

1 个答案:

答案 0 :(得分:0)

这是因为数据文件打乱了,数据集打乱了dataset.shuffle()

使用 dataset.shuffle(),默认情况下,每次迭代时数据将以不同的方式打乱。

可以删除 shuffle_files=True 并设置参数 reshuffle_each_iteration=False 以防止在不同的迭代中重新洗牌。

.take() 函数并不意味着确定性。它只会按照数据集给出的顺序从数据集中取出 N 个项目。

# Construct a tf.data.Dataset
ds = tfds.load('mnist', split='train', shuffle_files=False)

# Build your input pipeline
ds = ds.shuffle(1024, reshuffle_each_iteration=False).batch(32).prefetch(tf.data.experimental.AUTOTUNE)

single_batch_dataset = ds.take(1)

for example in single_batch_dataset:
    image, label = example["image"], example["label"]
    print(label)
    
for example in single_batch_dataset:
    image, label = example["image"], example["label"]
    print(label)

输出:

tf.Tensor([4 6 8 5 1 4 5 8 1 4 6 6 8 6 6 9 4 2 3 0 5 9 2 1 3 1 8 6 4 4 7 1], shape=(32,), dtype=int64)
tf.Tensor([4 6 8 5 1 4 5 8 1 4 6 6 8 6 6 9 4 2 3 0 5 9 2 1 3 1 8 6 4 4 7 1], shape=(32,), dtype=int64)
相关问题