PyTorch DataLoader 如何与 PyTorch 数据集交互以转换批次?

时间:2021-02-25 14:16:28

标签: python pytorch pytorch-dataloader

我正在为 NLP 相关任务创建自定义数据集。

在 PyTorch custom datast tutorial 中,我们看到 __getitem__() 方法在返回样本之前为转换留出了空间:

def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir,
                                self.landmarks_frame.iloc[idx, 0])
        image = io.imread(img_name)
       
        ### SOME DATA MANIPULATION HERE ###

        sample = {'image': image, 'landmarks': landmarks}
        if self.transform:
            sample = self.transform(sample)

        return sample

但是,这里的代码:

        if torch.is_tensor(idx):
            idx = idx.tolist()

暗示应该能够一次检索多个项目,这让我感到疑惑:

  1. 该转换如何处理多个项目?以教程中的自定义转换为例。它们看起来无法在一次调用中应用于一批样本。

  2. 相关,如果转换只能应用于单个样本,DataLoader 如何并行检索一批多个样本并应用所述转换?

2 个答案:

答案 0 :(得分:2)

  1. 这种转换如何处理多个项目?他们通过使用数据加载器处理多个项目。通过使用转换,您可以指定单个数据发射(例如,batch_size=1)应该发生什么。数据加载器获取指定的batch_size并对火炬数据集中的n方法进行__getitem__调用,将变换应用于发送到训练/的每个样本验证。然后它将 n 样本整理到从数据加载器发出的批大小中。

  2. 相关的,如果转换只能应用于单个样本,DataLoader 如何并行检索一批多个样本并应用所述转换?希望以上对您有意义。并行化由 Torch 数据集类和数据加载器完成,您可以在其中指定 num_workers。 Torch 会对数据集进行pickle 并将其传播给worker。

答案 1 :(得分:1)

来自 transforms from torchvision 的文档:

<块引用>

所有转换都接受 PIL 图像、张量图像或一批张量图像作为输入。 Tensor Image 是一个 (C, H, W) 形状的张量,其中 C 是通道数,H 和 W 是图像的高度和宽度。 Batch of Tensor Images 是 (B, C, H, W) 形状的张量,其中 B 是批次中的图像数量。应用于批量张量图像的确定性或随机变换相同地变换该批次的所有图像。

这意味着您可以传递一批图像,并且变换将应用于整个批次,只要它尊重形状。列表索引作用于数据帧中的 iloc,它选择单个索引或它们的列表,返回请求的子集。