有没有一种方法可以通过创建X张增强图像来生成DataLoader?我目前使用的代码只能创建一个扩展图像
class ImageDataset(data.Dataset):
def __init__(self, root_dir, transform=None):
self.root_dir = root_dir
self.img_names = os.listdir()
self.transform = transform
def __getitem__(self, index):
img = Image.open(os.path.join(self.root, self.img_names[index])).convert('RGB')
if self.transform is not None:
img = self.transform(img)
return img
def __len__(self):
return len(self.img_names)
此外,我想在同一图像的增强图像具有相同标签的地方添加标签
答案 0 :(得分:0)
有两种等效的方法可以实现。
我们可以更改数据集类本身以多次提供相同的数据。这可以通过报告更长的长度并使用索引(mod长度)选择图像名称来实现。
class ImageDataset(data.Dataset):
def __init__(self, root_dir, repetitions=1, transform=None):
self.root_dir = root_dir
self.img_names = os.listdir()
self.transform = transform
self.repetitions = repetitions
def __getitem__(self, index):
img = Image.open(os.path.join(self.root,
self.img_names[index % len(self.img_names)])).convert('RGB')
if self.transform is not None:
img = self.transform(img)
return img
def __len__(self):
return len(self.img_names) * self.repetitions
或者,我们可以使用torch.utils.data.Subset
并多次指定相同的索引。
# using your original implementation for ImageDataset
dataset = ImageDataset(root, transforms)
dataset = torch.utils.data.Subset(dataset, list(range(len(dataset))) * repetitions)