pytorch创建具有增强图像的数据集

时间:2020-07-07 17:50:31

标签: image-processing deep-learning pytorch

有没有一种方法可以通过创建X张增强图像来生成DataLoader?我目前使用的代码只能创建一个扩展图像

class ImageDataset(data.Dataset):
    
    def __init__(self, root_dir, transform=None):
        
        self.root_dir = root_dir
        self.img_names = os.listdir()
        self.transform = transform
        
    def __getitem__(self, index):
        
        img = Image.open(os.path.join(self.root, self.img_names[index])).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
            
        return img
            
    def __len__(self):
        
        return len(self.img_names)

此外,我想在同一图像的增强图像具有相同标签的地方添加标签

1 个答案:

答案 0 :(得分:0)

有两种等效的方法可以实现。

选项1

我们可以更改数据集类本身以多次提供相同的数据。这可以通过报告更长的长度并使用索引(mod长度)选择图像名称来实现。

class ImageDataset(data.Dataset):
    def __init__(self, root_dir, repetitions=1, transform=None):
        self.root_dir = root_dir
        self.img_names = os.listdir()
        self.transform = transform
        self.repetitions = repetitions
        
    def __getitem__(self, index):
        img = Image.open(os.path.join(self.root,
            self.img_names[index % len(self.img_names)])).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img
            
    def __len__(self):
        return len(self.img_names) * self.repetitions

选项2

或者,我们可以使用torch.utils.data.Subset并多次指定相同的索引。

# using your original implementation for ImageDataset
dataset = ImageDataset(root, transforms)
dataset = torch.utils.data.Subset(dataset, list(range(len(dataset))) * repetitions)