在CIFAR数据集上使用Pytorch创建自定义数据集

时间:2019-11-05 06:24:34

标签: python image deep-learning classification pytorch

我没有使用这些数据集的Pytorch内置API,而是尝试创建自己的数据集并将该数据集馈送到Pytorch的DATASET API和DATALOADER API。但是不知何故,我遇到了一些错误。

我的数据是通过将所有4个火车泡菜合并为一个而创建的。 IMAGES LABELS

创建数据并遵循此[CustomDataset] [3]之后,我编写了以下代码:

import numpy as np
import pickle as pkl
import cv2
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils


# For custom dataset inherit the parent Dataset class into the child class
class CIFARDataset(Dataset):
    """CIFAR dataset."""

    def __init__(self, pckl_path, transform=None):
        """

        :param pckl_path:
        :param transform:
        """
        " Load the pickle files data"
        pckl_fd = open(pckl_path, "rb")
        self.data_pckl = pkl.load(pckl_fd)

        self.transform = transform

    def __len__(self):
        return len(self.data_pckl)

    def __getitem__(self, idx):
        print("inside __get_item")
        if torch.is_tensor(idx):
            idx = idx.tolist()
        sample = {'image': self.data_pckl['images'][idx], 'label': self.data_pckl['labels'][idx]}
        if self.transform:
            sample = self.transform(sample)

        return sample

class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        print("In ToTensor")
        image, label = sample['images'], sample['labels']
        image = image.transpose((2, 0, 1))
        return {'image': torch.from_numpy(image),
                'label': torch.from_numpy(np.ndarray(label))}


dataset= CIFARDataset('cifar/train_set.pickle', transform=transforms.Compose(ToTensor()))
# composed = transforms.Compose([ToTensor()])
# sample = dataset.data_pckl
sample1 = {'images':None, 'labels': None}

data = dataset[0]

运行此命令时,出现以下错误:

错误:

data = dataset[0]
  File "/home/garud/Documents/DSP_notes/Project/create_dataset.py", line 34, in __getitem__
    sample = self.transform(sample)
  File "/home/garud/anaconda3/lib/python3.7/site-packages/torchvision/transforms/transforms.py", line 60, in __call__
    for t in self.transforms:
TypeError: 'ToTensor' object is not iterable

我调试并检查的示例是将传递给transform函数的字典。不知道哪里出了问题。

请忠告什么是错误的,以及需要遵循哪些最佳实践才能更好地做到这一点。

1 个答案:

答案 0 :(得分:0)

使用transforms.Compose编写转换时,需要提供转换的列表
试试:

transforms.Compose([ToTensor(), ])

您仍然只提供一个转换,但是它包装在一个列表中。

相关问题