CIFAR10 子集的自定义转换 - PyTorch

时间:2021-01-24 15:44:19

标签: python deep-learning pytorch

我正在尝试为 CIFAR10 数据集的一部分创建自定义转换,该转换将图像叠加在数据集上。我能够下载数据并将其分成子集。使用以下代码:

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

traindata = datasets.CIFAR10('./data', train=True, download=True,
                       transform= transform_train)

partitions = 5
traindata_split = torch.utils.data.random_split(traindata, [int(traindata.data.shape[0] / partitions) for _ in range(partitions)])

然后我想修改部分拆分,所以我创建了以下类和函数来使用,如下所示:

class MyDataset(Dataset): # https://discuss.pytorch.org/t/torch-utils-data-dataset-random-split/32209/3 
    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform
        
    def __getitem__(self, index):
        x, y = self.subset[index]
        if self.transform:
            x = self.transform(x)
        return x, y
        
    def __len__(self):
        return len(self.subset)

class ImageSuperImpose(object):
    """  Image input as PIL and output as PIL
        To be used as part of  torchvision.transforms
       Args: p, a threshold value to control image thinning        
    """
    def __init__(self, p=0):
        self.p = p                  
        
    def __call__(self, image):
        img = cv2.imread('img.jpg') 
        img = img('float32')/255
        imgSm = cv2.resize(img,(32,32))
        np_arr = image.cpu().detach().numpy().T
        sample = cv2.addWeighted(np_arr, 1, imgSm, 1, 0)
        sample = sample.T
        t = torch.from_numpy(sample)
        return sample

transform_train2 = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    ImagePoisoning(), 
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

datasetA = MyDataset(
    traindata_split[0], transform= transform_train2
)

test_loader = torch.utils.data.DataLoader(datasetA, batch_size=128, shuffle=True)

但是当我尝试在子集上训练模型时,出现以下错误:

RuntimeError: The size of tensor a (32) must match the size of tensor b (3) at non-singleton dimension 0

** 更新** 这是完整的错误

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-20-7428084b03be> in <module>()
----> 1 train(model, opt, test_loader, 3)

9 frames
<ipython-input-14-fcb03e1d7685> in client_update(client_model, optimizer, train_loader, epoch)
      5     client_model.train()
      6     for e in range(epoch):
----> 7         for batch_idx, (data, target) in enumerate(train_loader):
      8             data, target = data.to(device), target.to(device)
      9             optimizer.zero_grad()

/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in __next__(self)
    433         if self._sampler_iter is None:
    434             self._reset()
--> 435         data = self._next_data()
    436         self._num_yielded += 1
    437         if self._dataset_kind == _DatasetKind.Iterable and \

/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
    473     def _next_data(self):
    474         index = self._next_index()  # may raise StopIteration
--> 475         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    476         if self._pin_memory:
    477             data = _utils.pin_memory.pin_memory(data)

/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

<ipython-input-7-1bde43acaff0> in __getitem__(self, index)
      7         x, y = self.subset[index]
      8         if self.transform:
----> 9             x = self.transform(x)
     10         return x, y
     11 

/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py in __call__(self, img)
     65     def __call__(self, img):
     66         for t in self.transforms:
---> 67             img = t(img)
     68         return img
     69 

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

/usr/local/lib/python3.6/dist-packages/torchvision/transforms/transforms.py in forward(self, tensor)
    224             Tensor: Normalized Tensor image.
    225         """
--> 226         return F.normalize(tensor, self.mean, self.std, self.inplace)
    227 
    228     def __repr__(self):

/usr/local/lib/python3.6/dist-packages/torchvision/transforms/functional.py in normalize(tensor, mean, std, inplace)
    282     if std.ndim == 1:
    283         std = std.view(-1, 1, 1)
--> 284     tensor.sub_(mean).div_(std)
    285     return tensor
    286 

RuntimeError: The size of tensor a (32) must match the size of tensor b (3) at non-singleton dimension 0

0 个答案:

没有答案
相关问题