如何使用PyTorch的DataLoader确保批次包含所有工人的样品?

时间:2019-08-30 15:08:25

标签: pytorch dataloader

我想知道如何在PyTorch中使用torch.utils.data.DataLoader,尤其是在多工情况下。

我发现DataLoader的一批输出始终来自单个工作人员。 我希望在DataLoader中有一个队列,该队列存储来自所有工作程序的数据,并且DataLoader在队列中将它们混洗以输出随机批处理数据。我认为这是Tensorflow中tf.data.Dataset中的方式。 我们可以在PyTorch中实现类似的功能吗?我想通过使用多个工作程序从大型序列化文件(如Tfrecord)中加载数据集。在这种情况下,将源文件分批混合,这意味着混合工作程序的源非常重要。

请参考以下代码:

import random
import time

import torch


class MyDataset(torch.utils.data.Dataset):
    def __len__(self):
        return 50

    def __getitem__(self, idx):
        info = torch.utils.data.get_worker_info()

        time.sleep(random.uniform(0, 1))
        print("[{}]:{}".format(info.id, idx))
        return idx, info.id


if __name__ == '__main__':
    dataset = MyDataset()
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=5, shuffle=False, num_workers=2)
    for batch in dataloader:
        print(batch)

输出:

[0]:0
[1]:5
[0]:1
[1]:6
[0]:2
[0]:3
[1]:7
[0]:4
[tensor([0, 1, 2, 3, 4]), tensor([0, 0, 0, 0, 0])]
[1]:8
[1]:9
[tensor([5, 6, 7, 8, 9]), tensor([1, 1, 1, 1, 1])]
[0]:10
[0]:11
[1]:15
[1]:16
[0]:12
[1]:17
...

在这里,[0, 1, 2, 3, 4]中的[0, 0, 0, 0, 0][tensor([0, 1, 2, 3, 4]), tensor([0, 0, 0, 0, 0])]意味着该批次包含来自工作人员ID 0的第0到第4索引数据。 请注意,shuffle=True不能解决仅更改数据索引的问题。

在这种情况下,我想获得一个类似[tensor([0, 5, 1, 6, 2]), tensor([0, 1, 0, 1, 0])]的批处理。

2 个答案:

答案 0 :(得分:0)

请注意,指定了batch_size的多工人DataLoader将并行装载多个批次,因此,基本上一个批次总是来自一个工人。但是,通过执行以下操作,我已经达到了您所需的目标:

  1. 将批次大小设为1,因此每个工人一次只能生产一个样品

  2. 编写一个遍历DataLoader的后台进程,一次获取1个样本并将其插入队列。这样就可以将样品以不同的顺序排在队列中,而不是使用特定于工人的批次

  3. 具有批处理机制,例如collate_fn,它从队列中获取等于您的批处理大小的样本数量并将其馈送到模型中

如果您想更具体地创建批处理,例如像从特定工作人员中选择特定样本,则可以有多个队列。您的整理程序应进行修改以解决多个队列并从中选择。但是我怀疑是否需要这种特异性。

答案 1 :(得分:0)

我已经实现了一些简单的方法来解决类似的问题,其中我将大型视频文件作为训练数据,每个工作人员负责加载和预处理单个文件,然后从中获取样本。问题是,正如OP所描述的那样,使用Pytorch的默认数据加载机制,每个批次仅包含来自单个视频文件的样本。

首先,让我们回顾一下问题。在此简化的代码示例中,每个工作人员产生一个包含其零索引工作人员ID的张量。对于32名工人和4名工人的批量,我们希望每批包含8个零,8个,8个2和8个三。

from collections import defaultdict

import torch as T
import torch.utils.data as tdata


class Dataset(tdata.IterableDataset):
    def __init__(self, batch_size: int):
        self._bs = batch_size

    def __iter__(self):
        worker_info = tdata.get_worker_info()
        if not worker_info:
            raise NotImplementedError('Not implemented for num_workers=0')
        for _ in range(self._bs):
            yield T.tensor([worker_info.id])


batch_size = 32
num_workers = 4
dataset = Dataset(batch_size)
loader = tdata.DataLoader(dataset,
                          batch_size=batch_size,
                          num_workers=num_workers)


for batch in loader:
    counts = defaultdict(int)
    for n in batch.numpy().flatten():
        counts[n] += 1
    print(dict(counts))

代替代码打印:

{0: 32}
{1: 32}
{2: 32}
{3: 32}

这意味着第一批仅包含来自工作人员0的样本,第二批仅包含来自工作人员1的样本,以此类推。为此,我们将DataLoader的批次大小设置为batch_size // num_workers,并使用DataLoader上的简单包装,以汇集我们每个批次的每个工作人员的样本:

def pooled_batches(loader):
    loader_it = iter(loader)
    while True:
        samples = []
        for _ in range(loader.num_workers):
            try:
                samples.append(next(loader_it))
            except StopIteration:
                pass
        if len(samples) == 0:
            break
        else:
            yield T.cat(samples, dim=0)


batch_size = 32
num_workers = 4
dataset = Dataset(batch_size)
per_worker = batch_size // num_workers
loader = tdata.DataLoader(dataset,
                          batch_size=per_worker,
                          num_workers=num_workers)

for batch in pooled_batches(loader):
    counts = defaultdict(int)
    for n in batch.numpy().flatten():
        counts[n] += 1
    print(dict(counts))

现在可以打印出代码

{0: 8, 1: 8, 2: 8, 3: 8}
{0: 8, 1: 8, 2: 8, 3: 8}
{0: 8, 1: 8, 2: 8, 3: 8}
{0: 8, 1: 8, 2: 8, 3: 8}

符合预期。