Correct data loading, splitting and augmentation in PyTorch

Date: 2019-06-13 14:01:30

Tags: neural-network pytorch

The tutorial doesn't seem to explain how we should load, split, and apply proper augmentation.

Let's say we have a dataset of cats and dogs. The folder structure is:

data
  cat
    0101.jpg
    0201.jpg
    ...
  dogs
    0101.jpg
    0201.jpg
    ...

First, I loaded the dataset with the datasets.ImageFolder function. ImageFolder takes a transform argument where we could set some augmentations, but we don't want to apply augmentation to the test dataset! So let's keep transform=None.

    data = datasets.ImageFolder(root='data')

Obviously we don't have separate train and test folder structures, so I think a nice approach is to use the random_split function:

    split = 0.8  # fraction of the data used for training (illustrative value)
    train_size = int(split * len(data))
    test_size = len(data) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(data, [train_size, test_size])

Now, let's load the data in the following way:

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                              batch_size=8,
                                              shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=8,
                                              shuffle=True)

How can I apply transforms (data augmentation) to the train_loader images?

Basically, I need to:

1. Load the data from the folder structure explained above.
2. Split the data into a test/train part.
3. Apply augmentation on the train part.

2 Answers:

Answer 0 (score: 0)

I'm not sure if there is a recommended way, but this is how I would work around the problem:

Given that torch.utils.data.random_split() returns a Subset, we cannot (I double-checked, we really cannot) exploit its inner dataset, because both subsets share the same one (the only difference is in the indices). In this case, I would implement a simple class to apply transformations, something like this:

from torch.utils.data import Dataset

class ApplyTransform(Dataset):
    """
    Apply transformations to a Dataset

    Arguments:
        dataset (Dataset): A Dataset that returns (sample, target)
        transform (callable, optional): A function/transform to be applied on the sample
        target_transform (callable, optional): A function/transform to be applied on the target

    """
    def __init__(self, dataset, transform=None, target_transform=None):
        self.dataset = dataset
        self.transform = transform
        self.target_transform = target_transform
        # yes, you don't need these 2 lines below :(
        if transform is None and target_transform is None:
            print("Am I a joke to you? :)")

    def __getitem__(self, idx):
        sample, target = self.dataset[idx]
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self):
        return len(self.dataset)

Then use it before passing the dataset to the dataloader:

import torchvision.transforms as transforms

train_transform = transforms.Compose([
    transforms.ToTensor(),
    # ...
])
train_dataset = ApplyTransform(train_dataset, transform=train_transform)

# continue with DataLoaders...
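
A minimal sketch of how that continuation might look, reusing the train_dataset/test_dataset Subsets from the question; test_transform and the shuffle settings here are assumptions, not part of the original answer:

test_transform = transforms.Compose([
    transforms.ToTensor(),  # no augmentation on the test split
])
test_dataset = ApplyTransform(test_dataset, transform=test_transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=8, shuffle=False)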

Answer 1 (score: 0)

I think you can have a look at this: https://gist.github.com/kevinzakka/d33bf8d6c7f06a9d8c76d97a7879f5cb

import numpy as np
import torch
from torchvision import datasets, transforms
from torch.utils.data.sampler import SubsetRandomSampler

def get_train_valid_loader(data_dir,
                           batch_size,
                           augment,
                           random_seed,
                           valid_size=0.1,
                           shuffle=True,
                           show_sample=False,
                           num_workers=4,
                           pin_memory=False):
    """
    Utility function for loading and returning train and valid
    multi-process iterators over the CIFAR-10 dataset. A sample
    9x9 grid of the images can be optionally displayed.
    If using CUDA, num_workers should be set to 1 and pin_memory to True.
    Params
    ------
    - data_dir: path directory to the dataset.
    - batch_size: how many samples per batch to load.
    - augment: whether to apply the data augmentation scheme
      mentioned in the paper. Only applied on the train split.
    - random_seed: fix seed for reproducibility.
    - valid_size: percentage split of the training set used for
      the validation set. Should be a float in the range [0, 1].
    - shuffle: whether to shuffle the train/validation indices.
    - show_sample: plot 9x9 sample grid of the dataset.
    - num_workers: number of subprocesses to use when loading the dataset.
    - pin_memory: whether to copy tensors into CUDA pinned memory. Set it to
      True if using GPU.
    Returns
    -------
    - train_loader: training set iterator.
    - valid_loader: validation set iterator.
    """
    error_msg = "[!] valid_size should be in the range [0, 1]."
    assert ((valid_size >= 0) and (valid_size <= 1)), error_msg

    normalize = transforms.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2023, 0.1994, 0.2010],
    )

    # define transforms
    valid_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
    ])
    if augment:
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    else:
        train_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

    # load the dataset
    train_dataset = datasets.CIFAR10(
        root=data_dir, train=True,
        download=True, transform=train_transform,
    )

    valid_dataset = datasets.CIFAR10(
        root=data_dir, train=True,
        download=True, transform=valid_transform,
    )

    num_train = len(train_dataset)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))

    if shuffle:
        np.random.seed(random_seed)
        np.random.shuffle(indices)

    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset, batch_size=batch_size, sampler=valid_sampler,
        num_workers=num_workers, pin_memory=pin_memory,
    )

    # visualize some images
    if show_sample:
        sample_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=9, shuffle=shuffle,
            num_workers=num_workers, pin_memory=pin_memory,
        )
        data_iter = iter(sample_loader)
        images, labels = next(data_iter)  # use built-in next(); .next() is Python 2 style
        X = images.numpy().transpose([0, 2, 3, 1])
        plot_images(X, labels)  # plot_images is a plotting helper defined alongside this gist

    return (train_loader, valid_loader)

It seems that he uses sampler=train_sampler to do the split.
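
Transferring that idea to the cats/dogs ImageFolder setup from the question could look roughly like the sketch below. The transforms, image size, and the 80/20 split are illustrative assumptions; the key point is building two ImageFolder views of the same root, each with its own transform, and drawing disjoint index subsets from them:

import numpy as np
import torch
from torchvision import datasets, transforms
from torch.utils.data.sampler import SubsetRandomSampler

# two views of the same folder, each with its own transform
train_tf = transforms.Compose([
    transforms.Resize((224, 224)),   # illustrative size
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
test_tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
train_data = datasets.ImageFolder(root='data', transform=train_tf)
test_data = datasets.ImageFolder(root='data', transform=test_tf)

# shuffle the indices once, then draw disjoint subsets from each view
indices = np.random.permutation(len(train_data)).tolist()
split = int(0.8 * len(train_data))  # illustrative 80/20 split
train_sampler = SubsetRandomSampler(indices[:split])
test_sampler = SubsetRandomSampler(indices[split:])

train_loader = torch.utils.data.DataLoader(train_data, batch_size=8, sampler=train_sampler)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=8, sampler=test_sampler)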