RuntimeError:/pytorch/torch/lib/TH/generic/THTensorMath.c:2864中的张量大小不一致

时间:2018-05-21 18:28:11

标签: python pytorch

我正在尝试构建一个数据加载器,这就是它的样子

<@...>

当我尝试运行这段代码时,我得到了这个错误,我确实得到了错误的性质,我的所有图像可能都不是相同的形状,而且我的图像形状并不完全相同,但是如果我没有错,错误只会出现在我将它们送到网络时,因为图像都是不同的形状,但为什么它会在这里抛出错误呢? 关于我可能出错的地方的任何建议都会非常有帮助, 如果需要,我很乐意提供任何额外信息,

由于

`class WhaleData(Dataset):
def __init__(self, data_file, root_dir , transform = None):
    self.csv_file = pd.read_csv(data_file)
    self.root_dir = root_dir
    self.transform = transforms.Resize(224)

def __len__(self):
    return len(os.listdir(self.root_dir))

def __getitem__(self, index):
    image = os.path.join(self.root_dir, self.csv_file['Image'][index])
    image = Image.open(image)
    image = self.transform(image)
    image = np.array(image)
    label  = self.csv_file['Image'][index]
    sample = {'image': image, 'label':label}
    return sample

trainset  = WhaleData(data_file = '/mnt/55-91e8-b2383e89165f/Ryan/1234/train.csv', 
     root_dir = '/mnt/4d55-91e8-b2383e89165f/Ryan/1234/train')
train_loader = torch.utils.data.DataLoader(trainset , batch_size = 4, shuffle =True,num_workers= 2)
for i, batch in enumerate(train_loader):
      (i, batch)

1 个答案:

答案 0 :(得分:1)

当PyTorch尝试将图像堆叠成单个批量张量时(参见跟踪中的torch.stack([torch.from_numpy(b) for b in batch], 0)),会出现错误。正如您所提到的,由于图像具有不同的形状,因此堆叠失败(即,如果所有这些张量都具有形状(B, H, W)),则只能通过堆叠B张量来创建张量(H, W)。 / p>

注意:我不完全确定,但为batch_size=1设置torch.utils.data.DataLoader(...)可能会删除此特定错误,因为它可能不再需要调用torch.stack()