pad_sequence does not fix "The expanded size of the tensor () must match the existing size () at non-singleton dimension 2"

Time: 2020-01-13 15:37:56

Tags: pytorch

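I'm trying to add a collate_fn to my DataLoader so that it can load images of different sizes. However, it keeps raising this error: "The expanded size of the tensor (1444) must match the existing size (1936) at non-singleton dimension 2. Target sizes: [3, 1444, 1444]. Tensor sizes: [3, 1296, 1936]". This suggests that pad_sequence is not padding the images the way I expected. Any help would be greatly appreciated.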
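The failure can be reproduced without any real images. The sketch below uses small dummy tensors whose shapes are made up for illustration. `pad_sequence` pads only along the first dimension of each tensor before stacking them, so every tensor must already agree on all trailing dimensions. For a CHW image, dimension 0 is the channel count, which is 3 for every image. The height and width therefore end up as the trailing dimensions, and they differ between images.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Works: the tensors differ only in dim 0 (the "sequence length"),
# and the trailing feature dimension (4) is identical.
a = torch.zeros(2, 4)
b = torch.zeros(5, 4)
padded = pad_sequence([a, b], batch_first=True)
print(padded.shape)  # torch.Size([2, 5, 4])

# Fails: two dummy CHW "images" with the same channel count but different H and W.
# pad_sequence treats C (=3) as the length and expects H and W to already match.
img1 = torch.zeros(3, 4, 4)
img2 = torch.zeros(3, 2, 6)
failed = False
try:
    pad_sequence([img1, img2], batch_first=False)
except RuntimeError as err:
    failed = True
    print("RuntimeError:", err)
```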

import torch
import torchvision
import torchvision.transforms as transforms
from torch.nn.utils.rnn import pad_sequence

# collate function
def my_collate(batch):
    # batch is a list of (image, target) tuples; each image is a [3, H, W] tensor
    targets = [item[1] for item in batch]  # list of labels
    data = [item[0] for item in batch]
    data = pad_sequence(data, batch_first=False)  # this call raises the RuntimeError below
    return [data, targets]

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# training
trainset = torchvision.datasets.ImageFolder(root='/content/output/train', transform=transform, target_transform=None)  # the transform could also crop the images
trainloader = torch.utils.data.DataLoader(trainset, batch_size=8, shuffle=False, num_workers=0, collate_fn=my_collate, pin_memory=True)

RuntimeError                              Traceback (most recent call last)
<ipython-input-77-0efe18aac68b> in <module>()
     10   correct = 0
     11   total = 0
---> 12   for i, data in enumerate(trainloader, 0):
     13       inputs ,labels = data #get inputs
     14       # inputs = torch.FloatTensor(inputs)

3 frames
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in __next__(self)
    344     def __next__(self):
    345         index = self._next_index()  # may raise StopIteration
--> 346         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    347         if self._pin_memory:
    348             data = _utils.pin_memory.pin_memory(data)

/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     45         else:
     46             data = self.dataset[possibly_batched_index]
---> 47         return self.collate_fn(data)

<ipython-input-76-d9f902bb6b0f> in my_collate(batch)
     15     print(len(data[0][1][0]))
     16     print(len(data[0][2]))
---> 17     data = pad_sequence(data, batch_first= False)
     18     return [data, targets]
     19 

/usr/local/lib/python3.6/dist-packages/torch/nn/utils/rnn.py in pad_sequence(sequences, batch_first, padding_value)
    389             out_tensor[i, :length, ...] = tensor
    390         else:
--> 391             out_tensor[:length, i, ...] = tensor
    392 
    393     return out_tensor
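One way around this is to pad the spatial dimensions directly instead of using `pad_sequence`. The following is a minimal sketch, not a definitive fix. It pads every image in the batch on the bottom and right to the batch's largest height and width, then stacks the results. The name `pad_collate` is hypothetical. The function also assumes each dataset item is an `(image, int_label)` pair with the image as a `[C, H, W]` tensor, which is what `ImageFolder` with `ToTensor()` produces.

```python
import torch
import torch.nn.functional as F

def pad_collate(batch, padding_value=0.0):
    """Hypothetical collate_fn: pad CHW images to the largest H and W in the batch."""
    images = [item[0] for item in batch]
    targets = torch.tensor([item[1] for item in batch])  # assumes integer class labels
    max_h = max(img.shape[1] for img in images)
    max_w = max(img.shape[2] for img in images)
    padded = []
    for img in images:
        pad_h = max_h - img.shape[1]
        pad_w = max_w - img.shape[2]
        # For the last two dims, F.pad takes (left, right, top, bottom):
        # here padding is added only on the right and at the bottom.
        padded.append(F.pad(img, (0, pad_w, 0, pad_h), value=padding_value))
    # All images now share the shape [C, max_h, max_w], so they can be stacked.
    return torch.stack(padded), targets
```

To use it, pass it as `collate_fn=pad_collate` to the DataLoader. Note that `Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` maps pixels to [-1, 1], so a padding value of 0 corresponds to mid-gray rather than black. Choose `padding_value` accordingly. If the model can take fixed-size inputs, a simpler alternative is to resize or crop every image to one size in the `transform`.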
