Pytorch何时会发生数据加载器改组?

时间:2020-05-10 21:24:50

标签: python machine-learning pytorch shuffle dataloader

我一直在为pytorch数据加载器使用shuffle选项很多次。但是我想知道这种混洗何时发生以及是否在迭代过程中动态执行。以下面的代码为例:

namesDataset = NamesDataset()
namesTrainLoader = DataLoader(namesDataset, batch_size=16, shuffle=True)
for batch_data in namesTrainLoader:
    print(batch_data)

当我们定义“ namesTrainLoader”时,这是否意味着改组已完成,并且以下迭代将基于固定的数据顺序?定义了namesTrainLoader之后,for循环中是否会有任何随机性?

我试图用一些特殊值替换“ batch_data”的一半:

for batch_data in namesTrainLoader:
    batch_data[:8] = special_val
    pre = model(batch_data)

让我们说将有无限数量的时代,“模型”最终会看到“ namesTrainLoader”中的所有数据吗?还是“ namesTrainLoader”的一半数据实际上丢失了“模型”?

2 个答案:

答案 0 :(得分:4)

改组在创建迭代器时发生。对于for循环,这种情况恰好在for循环开始之前。

您可以使用以下方法手动创建迭代器:

# Iterator gets created, the data has been shuffled at this point.
data_iterator = iter(namesTrainLoader)

默认情况下,如果您设置了shuffle=True(未提供自己的采样器),则数据加载器将使用torch.utils.data.RandomSampler。它的实现非常简单,您可以通过查看RandomSampler.__iter__方法来查看创建迭代器时数据在何处进行混编:

def __iter__(self):
    n = len(self.data_source)
    if self.replacement:
        return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
    return iter(torch.randperm(n).tolist())

return语句是重排的重要部分。它只是创建索引的随机排列。

这意味着每次完全消耗迭代器时,您将看到整个数据集,而每次的顺序都不同。因此,不会丢失任何数据(不包括drop_last=True的情况),您的模型将在每个时期看到所有数据。

答案 1 :(得分:1)

您可以检查PyTorch对torch.utils.data.DataLoader here的实现。

如果您指定shuffle=True,将使用torch.utils.data.RandomSampler(否则将使用SequentialSampler)。

创建DataLoader的实例时,不会洗牌,只会实例化对象的必要私有成员和诸如事物之类的其他设置。

当您在迭代过程中发出特殊的__iter__方法时(如您的情况),将返回一个名为_SingleProcessDataLoader(self)的特殊对象,该对象是数据生成器(可能是批处理,混洗等),假设您不使用多处理)。

要找到所有与私有方法和助手相关的方法都有些麻烦,但是它的基本作用是使用基础的sampler获取用于从{{1 }}。

运行采样器直到耗尽,然后重复该过程(通常是一个时期)。

namesTrainLoader之后的for循环中是否会有任何随机性 被定义了?

在每个循环/纪元torch.utils.data.Dataset开始,索引会随机播放,因此,是的,它将在每个纪元(调用RandomSampler和新{{ 1}}返回),可以无限期完成。

[...]“模型”最终会看到“ namesTrainLoader”中的所有数据吗?

是的,它很可能最终会看到所有数据点