我正在做一个 image retrieval
项目,为了使模型更公平,我想构建返回的批次:
5 images
每个班级,以及75 images
和每批我的数据集中总共有 300 classes
,因此很明显每个批次中只能包含 15 类图像。数据是平衡的这意味着有相同数量的图像每节课,我使用 pytorch
。
我已经创建了 pytorch 数据集,我想在我的 ImageFolderLoader class
中添加上述功能,我在下面添加了其代码。
IMG_EXTENSIONS = [
'.jpg', '.JPG', '.jpeg', '.JPEG',
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]
def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
def find_classes(dir):
classes = os.listdir(dir)
classes.sort()
class_to_idx = {classes[i]: i for i in range(len(classes))}
classes = [clss.split('.')[1] for clss in classes]
return classes, class_to_idx
def make_dataset(dir, class_to_idx):
images = []
for target in os.listdir(dir):
d = os.path.join(dir, target)
if not os.path.isdir(d):
continue
for filename in os.listdir(d):
if is_image_file(filename):
path = '{0}/{1}'.format(target, filename)
item = (path, class_to_idx[target])
images.append(item)
return images
def default_loader(path):
return Image.open(path).convert('RGB')
class ImageFolderLoader(Dataset):
def __init__(self, root, transform=None, loader=default_loader,):
classes, class_to_idx = find_classes(root)
imgs = make_dataset(root, class_to_idx)
self.root = root
self.imgs = imgs
self.classes = classes
self.class_to_idx = class_to_idx
self.transform = transform
self.loader = loader
def __getitem__(self, index):
path, target = self.imgs[index]
img = self.loader(os.path.join(self.root, path))
if self.transform is not None:
img = self.transform(img)
return img, target
def __len__(self):
return len(self.imgs)
如果有办法做到这一点,请告诉我>。
编辑:- 任何人都想看到此问题的解决方案,我在解决此问题后添加了解决方案 below。
答案 0 :(得分:0)
关于如何实施有一些开放式问题。例如,您是否希望每个类都被平等地代表,而不管该类的实际频率如何?请注意,这可能会以牺牲多数类的性能为代价,在少数类上获得更好的性能。
此外,您希望每个示例每个 epoch 最多使用一次,还是每个 epoch 至少使用一次?
无论如何,使用标准的 getitem 方法可能很难实现这一点,因为它返回的示例不考虑同一批次中返回的其他示例。您可能需要定义一个自定义的 dataloader
对象以确保良好的数据分布和使用属性,这有点遗憾,因为 pytorch 的数据加载器和数据集对象在大多数简单用例中都可以很好且高效地协同工作。也许其他人有使用这些对象的解决方案。
这里有一个解决方案,它在每批之后使用随机抽样和替换,因此不能保证每个示例都会被使用。此外,它使用循环,因此您可能可以通过并行化做得更好。
class ImageFolderLoader(Dataset):
def __init__(self, root, transform=None, loader=default_loader,):
classes, class_to_idx = find_classes(root)
imgs = make_dataset(root, class_to_idx)
#currently, imgs items are of the form (path,class)
data_dict = {}
for item in imgs:
cls = item[1]
if cls not in data_dict.keys():
data_dict[cls] = [item]
else:
data_dict[cls].append(item)
# each class is the key for a list of all items belonging to that class
self.data_dict = data_dict
self.root = root
self.imgs = imgs
self.classes = classes
self.class_to_idx = class_to_idx
self.transform = transform
self.loader = loader
def get_batch(self):
img_batch = []
label_batch = []
classes = random.sample((0,300),15)
for cls in classes:
class_data = self.data_dict[cls]
selection = random.sample((0,len(class_data),5)
for idx in selection:
img = self.loader(os.path.join(self.root, class_data[idx][0]))
if self.transform is not None:
img = self.transform(img)
img_batch.append(img)
label_batch.append(cls)
img_batch = torch.stack(img_batch)
label_batch = torch.stack(label_batch)
return img_batch, label_batch
def __len__(self):
return len(self.imgs)
答案 1 :(得分:0)
我通过在 batch_sampler
模块中包含 DataLoader
解决了这个问题。为此,我使用了 pytorch-balanced-sampler git 项目,它允许对 batch_sampler 进行出色的自定义,您应该访问这个 repo。
我的自定义数据集:
IMG_EXTENSIONS = [
'.jpg', '.JPG', '.jpeg', '.JPEG',
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]
def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
def find_classes(dir):
classes = os.listdir(dir)
classes.sort()
class_to_idx = {classes[i]: i for i in range(len(classes))}
classes = [clss.split('.')[1] for clss in classes]
return classes, class_to_idx
def make_dataset(dir, class_to_idx):
images = []
for target in os.listdir(dir):
d = os.path.join(dir, target)
if not os.path.isdir(d):
continue
for filename in os.listdir(d):
if is_image_file(filename):
path = '{0}/{1}'.format(target, filename)
item = (path, class_to_idx[target])
images.append(item)
data_dict = {}
for item in images:
cls = item[1]
if cls not in data_dict.keys():
data_dict[cls] = [item]
else:
data_dict[cls].append(item)
return images,data_dict
def default_loader(path):
return Image.open(path).convert('RGB')
class ImageFolderLoader(Dataset):
def __init__(self, root, transform=None, loader=default_loader):
classes, class_to_idx = find_classes(root)
imgs,instance_labels = make_dataset(root, class_to_idx)
self.instance_labels = instance_labels
self.root = root
self.imgs = imgs
self.classes = classes
self.class_to_idx = class_to_idx
self.transform = transform
self.loader = loader
def __getitem__(self, index):
path, target = self.imgs[index]
img = self.loader(os.path.join(self.root, path))
if self.transform is not None:
img = self.transform(img)
return img, target
def __len__(self):
return len(self.imgs)
然后我使用了 pytorch-balances-sampler 项目中的 SamplerFactory
类,您需要访问此存储库以了解参数,
train_data = ImageFolderLoader(root=TRAIN_PATH, transform=transform)
batch_sampler = SamplerFactory().get(
class_idxs=my_list,
batch_size=75,
n_batches=146,
alpha=1,
kind="fixed"
)