Pytorch Dataloader未将数据拆分为批处理

时间:2020-06-11 08:05:19

标签: python machine-learning computer-vision pytorch

我有这样的数据集类:

class LoadDataset(Dataset):
    def __init__(self, data, label):
        self.data = data
        self.label = label
    def __len__(self):
        dlen = len(self.data)
        return dlen
    def __getitem__(self, index):
        return self.data, self.label

然后我加载形状为[485、1、32、32]的图像数据集

train_dataset = LoadDataset(xtrain, ytrain)
print(len(train_dataset))
# output 485

然后我用DataLoader

加载数据
train_loader = DataLoader(train_dataset, batch_size=32)

然后我迭代数据:

for epoch in range(num_epoch):
        for inputs, labels in train_loader:   
            print(inputs.shape)

输出显示torch.Size([32, 485, 1, 32, 32]),应为torch.Size([32, 1, 32, 32])

有人可以帮助我吗?

1 个答案:

答案 0 :(得分:1)

__getitem__方法应该返回1个数据,而您全部返回了。

尝试一下:

class LoadDataset(Dataset):
    def __init__(self, data, label):
        self.data = data
        self.label = label
    def __len__(self):
        dlen = len(self.data)
        llen = len(self.label)  # different here
        return min(dlen, llen)  # different here
    def __getitem__(self, index):
        return self.data[index], self.label[index]  # different here