pytorch batch_size没有返回正确的批处理?

时间:2020-07-12 20:14:47

标签: deep-learning pytorch

无论我为batch_size设置什么,batch_size默认为1。这是我的代码

train_dataset = DataLoader(dataset=dataset,
                       batch_size=4,
                       shuffle=True,
                       num_workers=0)

并且数据集是如下的自定义数据集

class ImageDataset(data.Dataset):

def __init__(self, root_dir, num_augments=2, transform=None):
    
    self.root_dir = root_dir
    self.img_names = os.listdir(root_dir)[::600]
    self.num_augments = num_augments
    self.transform = transform
    
def __getitem__(self, index):
    
    output = []
    img = Image.open(self.root_dir + '/' + self.img_names[index]).convert('RGB')
        
    for i in range(self.num_augments):
        if self.transform is not None:
            img_transform = self.transform(img)
            
        output.append(img_transform)
        
    output = torch.stack(output, axis=0)
        
    return output
        
def __len__(self):
    
    return len(self.img_names)

我希望每个批次的大小为[batch,num_augments,3,高度,宽度],但是无论我的批次大小如何,我都会得到[1,num_augments,3,高度,宽度]。

0 个答案:

没有答案