如何为每个类构建返回相同数量图像的批次

时间:2021-05-14 14:00:53

标签: python deep-learning computer-vision pytorch

我正在做一个 image retrieval 项目,为了使模型更公平,我想构建返回的批次:

  • 5 images 每个班级,以及
  • 75 images 和每批

我的数据集中总共有 300 classes,因此很明显每个批次中只能包含 15 类图像。数据是平衡的这意味着有相同数量的图像每节课,我使用 pytorch

我已经创建了 pytorch 数据集,我想在我的 ImageFolderLoader class 中添加上述功能,我在下面添加了其代码。

IMG_EXTENSIONS = [
   '.jpg', '.JPG', '.jpeg', '.JPEG',
   '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def find_classes(dir):
    classes = os.listdir(dir)
    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    classes = [clss.split('.')[1] for clss in classes]
    return classes, class_to_idx

def make_dataset(dir, class_to_idx):
    images = []
    for target in os.listdir(dir):
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            continue
        for filename in os.listdir(d):
            if is_image_file(filename):
                path = '{0}/{1}'.format(target, filename)
                item = (path, class_to_idx[target])
                images.append(item)
                
    return images

def default_loader(path):
    return Image.open(path).convert('RGB')

class ImageFolderLoader(Dataset):
    def __init__(self, root, transform=None, loader=default_loader,):
        classes, class_to_idx = find_classes(root)
        imgs = make_dataset(root, class_to_idx)
        
        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.loader = loader
        
    def __getitem__(self, index):
        path, target = self.imgs[index]
        img = self.loader(os.path.join(self.root, path))
        if self.transform is not None:
            img = self.transform(img)
            
        return img, target
    
    def __len__(self):
        return len(self.imgs)

如果有办法做到这一点,请告诉我>。

编辑:- 任何人都想看到此问题的解决方案,我在解决此问题后添加了解决方案 below

2 个答案:

答案 0 :(得分:0)

关于如何实施有一些开放式问题。例如,您是否希望每个类都被平等地代表,而不管该类的实际频率如何?请注意,这可能会以牺牲多数类的性能为代价,在少数类上获得更好的性能。

此外,您希望每个示例每个 epoch 最多使用一次,还是每个 epoch 至少使用一次?

无论如何,使用标准的 getitem 方法可能很难实现这一点,因为它返回的示例不考虑同一批次中返回的其他示例。您可能需要定义一个自定义的 dataloader 对象以确保良好的数据分布和使用属性,这有点遗憾,因为 pytorch 的数据加载器和数据集对象在大多数简单用例中都可以很好且高效地协同工作。也许其他人有使用这些对象的解决方案。

这里有一个解决方案,它在每批之后使用随机抽样和替换,因此不能保证每个示例都会被使用。此外,它使用循环,因此您可能可以通过并行化做得更好。

class ImageFolderLoader(Dataset):
  def __init__(self, root, transform=None, loader=default_loader,):
    classes, class_to_idx = find_classes(root)
    imgs = make_dataset(root, class_to_idx)

    #currently, imgs items are of the form (path,class)

    data_dict = {}
    for item in imgs:
       cls = item[1]
       if cls not in data_dict.keys():
           data_dict[cls] = [item]
       else:
           data_dict[cls].append(item)  
   
    # each class is the key for a list of all items belonging to that class
    self.data_dict = data_dict 

    self.root = root
    self.imgs = imgs
    self.classes = classes
    self.class_to_idx = class_to_idx
    self.transform = transform
    self.loader = loader
    
  def get_batch(self):
    img_batch = []
    label_batch = []
    
    classes = random.sample((0,300),15) 
    for cls in classes:
        class_data = self.data_dict[cls]
        selection = random.sample((0,len(class_data),5)
        for idx in selection:
           img = self.loader(os.path.join(self.root, class_data[idx][0]))
           if self.transform is not None:
               img = self.transform(img)
           img_batch.append(img)
           label_batch.append(cls)
   
    img_batch = torch.stack(img_batch)
    label_batch = torch.stack(label_batch)

    return img_batch, label_batch

  def __len__(self):
    return len(self.imgs)

答案 1 :(得分:0)

我通过在 batch_sampler 模块中包含 DataLoader 解决了这个问题。为此,我使用了 pytorch-balanced-sampler git 项目,它允许对 batch_sampler 进行出色的自定义,您应该访问这个 repo。

我的自定义数据集:

IMG_EXTENSIONS = [
   '.jpg', '.JPG', '.jpeg', '.JPEG',
   '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def find_classes(dir):
    classes = os.listdir(dir)
    classes.sort()
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    classes = [clss.split('.')[1] for clss in classes]
    return classes, class_to_idx

def make_dataset(dir, class_to_idx):
    images = []
    for target in os.listdir(dir):
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            continue
        for filename in os.listdir(d):
            if is_image_file(filename):
                path = '{0}/{1}'.format(target, filename)
                item = (path, class_to_idx[target])
                images.append(item)
        
    data_dict = {}
    for item in images:
        cls = item[1]
        if cls not in data_dict.keys():
            data_dict[cls] = [item]
        else:
            data_dict[cls].append(item) 
        
    return images,data_dict

def default_loader(path):
    return Image.open(path).convert('RGB')

class ImageFolderLoader(Dataset):
    def __init__(self, root, transform=None, loader=default_loader):
        classes, class_to_idx = find_classes(root)
        imgs,instance_labels = make_dataset(root, class_to_idx)
        
        
        self.instance_labels = instance_labels
        
        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.loader = loader
        
    def __getitem__(self, index):
        path, target = self.imgs[index]
        img = self.loader(os.path.join(self.root, path))
        if self.transform is not None:
            img = self.transform(img)
            
        return img, target

    def __len__(self):
        return len(self.imgs)

然后我使用了 pytorch-balances-sampler 项目中的 SamplerFactory 类,您需要访问此存储库以了解参数,

train_data = ImageFolderLoader(root=TRAIN_PATH, transform=transform)
batch_sampler = SamplerFactory().get(
    class_idxs=my_list,
    batch_size=75,
    n_batches=146,
    alpha=1,
    kind="fixed"
)
相关问题