Pytorch自定义数据集类给出错误的输出

时间:2020-02-19 05:27:22

标签: python pytorch

我正在尝试使用为数据集构建的此类,但它表示它应该是PIL或ndarray。我不太确定这是怎么回事。这是我正在使用的课程

class RotateDataset(Dataset):
    def __init__(self, image_list, size,transform = None):
        self.image_list = image_list
        self.size = size
        self.transform = transform
    def __len__(self):
        return len(self.image_list)
    def __getitem__(self, idx):

        img = cv2.imread(self.image_list[idx])
        image_height, image_width = img.shape[:2]
        print("ID: ", idx)
        if idx % 2 == 0:
            label = 0 # Set label
            # chose negative or positive rotation
            rotation_degree = random.randrange(35, 50, 1)
            posnegrot = np.random.randint(2)
            if posnegrot == 0:
                #positive rotation
                #rotation_matrix = cv2.getRotationMatrix2D((num_cols/2, num_rows/2), rotation_degree, 1)
                #img = cv2.warpAffine(img, rotation_matrix, (num_cols, num_rows))

                img = rotate_image(img, rotation_degree)
                img = crop_around_center(img, *largest_rotated_rect(image_width,
                                                                image_height,
                                                                math.radians(rotation_degree)))
            else:
                # Negative rotation
                rotation_degree = -rotation_degree
                img = crop_around_center(img, *largest_rotated_rect(image_width,
                                                                image_height,
                                                                math.radians(rotation_degree)))

        else:
           label = 1
        img = cv2.resize(img, self.size, cv2.INTER_AREA)
        return self.transform(img), self.transform(label)

它给我的错误是

TypeError: pic should be PIL Image or ndarray. Got class 'int'

它应该给我一个img(张量)和一个标签(张量) 但我不认为它做得正确。

TypeError                                 Traceback (most recent call last)
<ipython-input-34-f47943b2600c> in <module>
      2     train_loss = 0.0
      3     net.train()
----> 4     for image, label in enumerate(train_loader):
      5         if train_on_gpu:
      6             image, label = image.cuda(), label.cuda()

~\Anaconda3\envs\TF2\lib\site-packages\torch\utils\data\dataloader.py in __next__(self)
    343 
    344     def __next__(self):
--> 345         data = self._next_data()
    346         self._num_yielded += 1
    347         if self._dataset_kind == _DatasetKind.Iterable and \

~\Anaconda3\envs\TF2\lib\site-packages\torch\utils\data\dataloader.py in _next_data(self)
    383     def _next_data(self):
    384         index = self._next_index()  # may raise StopIteration
--> 385         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    386         if self._pin_memory:
    387             data = _utils.pin_memory.pin_memory(data)

~\Anaconda3\envs\TF2\lib\site-packages\torch\utils\data\_utils\fetch.py in fetch(self, possibly_batched_index)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

~\Anaconda3\envs\TF2\lib\site-packages\torch\utils\data\_utils\fetch.py in <listcomp>(.0)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

<ipython-input-28-6c77357ff619> in __getitem__(self, idx)
     35             label = 1
     36         img = cv2.resize(img, self.size, cv2.INTER_AREA)
---> 37         return self.transform(img), self.transform(label)

~\Anaconda3\envs\TF2\lib\site-packages\torchvision\transforms\transforms.py in __call__(self, pic)
     99             Tensor: Converted image.
    100         """
--> 101         return F.to_tensor(pic)
    102 
    103     def __repr__(self):

~\Anaconda3\envs\TF2\lib\site-packages\torchvision\transforms\functional.py in to_tensor(pic)
     53     """
     54     if not(_is_pil_image(pic) or _is_numpy(pic)):
---> 55         raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))
     56 
     57     if _is_numpy(pic) and not _is_numpy_image(pic):

TypeError: pic should be PIL Image or ndarray. Got <class 'int'>

1 个答案:

答案 0 :(得分:1)

正如评论中所讨论的那样,问题也是在label上应用了转换。相反,label应该简单地写为张量:

return self.transform(img), torch.tensor(label)