批量加载数据是否可变?

时间:2018-07-29 23:12:40

标签: python-3.x image-processing pytorch

我目前正在研究基于补丁的超分辨率。大多数论文将图像分成较小的补丁,然后将补丁用作模型的输入。我能够使用自定义数据加载器创建补丁。代码如下:

import torch.utils.data as data
from torchvision.transforms import CenterCrop, ToTensor, Compose, ToPILImage, Resize, RandomHorizontalFlip, RandomVerticalFlip
from os import listdir
from os.path import join
from PIL import Image
import random
import os
import numpy as np
import torch

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg", ".bmp"])

class TrainDatasetFromFolder(data.Dataset):
    def __init__(self, dataset_dir, patch_size, is_gray, stride):
        super(TrainDatasetFromFolder, self).__init__()
        self.imageHrfilenames = []
        self.imageHrfilenames.extend(join(dataset_dir, x)
                                     for x in sorted(listdir(dataset_dir)) if is_image_file(x))
        self.is_gray = is_gray
        self.patchSize = patch_size
        self.stride = stride

    def _load_file(self, index):
        filename = self.imageHrfilenames[index]
        hr = Image.open(self.imageHrfilenames[index])
        downsizes = (1, 0.7, 0.45)
        downsize = 2
        w_ = int(hr.width * downsizes[downsize])
        h_ = int(hr.height * downsizes[downsize])
        aug = Compose([Resize([h_, w_], interpolation=Image.BICUBIC),
                       RandomHorizontalFlip(),
                       RandomVerticalFlip()])

        hr = aug(hr)
        rv = random.randint(0, 4)
        hr = hr.rotate(90*rv, expand=1)
        filename = os.path.splitext(os.path.split(filename)[-1])[0]
        return hr, filename

    def _patching(self, img):

        img = ToTensor()(img)
        LR_ = Compose([ToPILImage(), Resize(self.patchSize//2, interpolation=Image.BICUBIC), ToTensor()])

        HR_p, LR_p = [], []
        for i in range(0, img.shape[1] - self.patchSize, self.stride):
            for j in range(0, img.shape[2] - self.patchSize, self.stride):
                temp = img[:, i:i + self.patchSize, j:j + self.patchSize]
                HR_p += [temp]
                LR_p += [LR_(temp)]

        return torch.stack(LR_p),torch.stack(HR_p)

    def __getitem__(self, index):
        HR_, filename = self._load_file(index)
        LR_p, HR_p = self._patching(HR_)
        return LR_p, HR_p

    def __len__(self):
        return len(self.imageHrfilenames)

假设批处理大小为1,则它拍摄一张图像并给出大小为[x,3,patchsize,patchsize]的输出。当批处理大小为2时,我将有两个不同的输出,大小为[x,3,patchsize,patchsize](例如,图像1可以给出[50,3,patchsize,patchsize],图像2可以给出[75,3,patchsize,patchsize])。为此,需要一个自定义的整理函数,该函数沿维度0堆叠这两个输出。整理函数如下:

def my_collate(batch):
    data = torch.cat([item[0] for item in batch],dim = 0)
    target = torch.cat([item[1] for item in batch],dim = 0)

    return [data, target]

此整理函数沿x串联(从上面的示例中,我最终得到[125,3,patchsize,pathsize]。出于训练的目的,我需要使用最小批量大小为25训练模型。是否有任何方法或函数可以我可以使用必要数量的图像作为数据输入到数据加载器,直接从数据加载器直接获取大小为[25 , 3, patchsize, pathsize]的输出吗?

3 个答案:

答案 0 :(得分:2)

(相关,但不完全是主题)

对于批量大小调整,您可以使用this repo中所示的代码。出于不同的目的(最大程度地利用GPU内存)实施了该实现,但将其转化为问题并不难。

该代码进行批处理适应和批处理欺骗。

答案 1 :(得分:1)

以下代码段适用于您的目的。

首先,我们定义一个ToyDataset,它接收tensors的张量(variable length in dimension 0)的列表。这类似于您的数据集返回的样本。

import torch
from torch.utils.data import Dataset
from torch.utils.data.sampler import RandomSampler

class ToyDataset(Dataset):
    def __init__(self, tensors):
        self.tensors = tensors

    def __getitem__(self, index):
        return self.tensors[index]

    def __len__(self):
        return len(tensors)

第二,我们定义一个自定义数据加载器。创建数据集和数据加载器的通常的Pytorch二分法大致如下:有一个索引dataset,您可以向其传递索引,它从数据集中返回相关的样本。有一个sampler会产生一个索引,有不同的绘制索引的策略会导致不同的采样器。 batch_sampler使用采样器一次绘制多个索引(与batch_size指定的数目相同)。有一个dataloader结合了采样器和数据集,可以让您遍历一个数据集,重要的是数据加载器还拥有一个函数(collate_fn),该函数指定如何使用索引来从数据集中检索多个样本batch_sampler应该合并。对于您的用例,通常的PyTorch二分法效果不佳,因为除了绘制固定数量的索引外,我们还需要绘制索引,直到与索引关联的对象超过所需的累积大小为止。这意味着我们需要立即检查对象,并使用此知识来决定是退还批次还是保留工程图索引。这是下面的自定义数据加载器的作用:

class CustomLoader(object):

    def __init__(self, dataset, my_bsz, drop_last=True):
        self.ds = dataset
        self.my_bsz = my_bsz
        self.drop_last = drop_last
        self.sampler = RandomSampler(dataset)

    def __iter__(self):
        batch = torch.Tensor()
        for idx in self.sampler:
            batch = torch.cat([batch, self.ds[idx]])
            while batch.size(0) >= self.my_bsz:
                if batch.size(0) == self.my_bsz:
                    yield batch
                    batch = torch.Tensor()
                else:
                    return_batch, batch = batch.split([self.my_bsz,batch.size(0)-self.my_bsz])
                    yield return_batch
        if batch.size(0) > 0 and not self.drop_last:
            yield batch

这里,我们遍历数据集,绘制索引并加载关联的对象后,将其连接到我们之前绘制的张量(batch)。我们一直这样做,直到达到所需的大小,这样我们才能切出并批量生产。我们将行保留在batch中,但没有产生。因为单个实例可能超过了所需的batch_size,所以我们使用while loop

您可以修改此最小的CustomDataloader以添加PyTorch数据加载器样式的更多功能。也不需要使用RandomSampler提取索引,其他索引也可以很好地工作。如果您的数据很大,例如通过使用列表并跟踪其张量的累积长度,还可以避免重复出现。

以下是一个示例,演示了它的工作原理:

patch_size = 5
channels = 3
dim0sizes = torch.LongTensor(100).random_(1, 100)
data = torch.randn(size=(dim0sizes.sum(), channels, patch_size, patch_size))
tensors = torch.split(data, list(dim0sizes))

ds = ToyDataset(tensors)
dl = CustomLoader(ds, my_bsz=250, drop_last=False)
for i in dl:
    print(i.size(0))

答案 2 :(得分:0)

为了改进之前的答案,我找到了一个 repo,它使用 DataManger 来实现不同的补丁大小和批量大小。它基本上是启动具有不同设置的不同数据加载器,并使用 set_epoch 函数为给定的 epoch 设置适当的数据加载器。