如何在 PyTorch 自定义数据集类中处理 None?

时间:2021-02-25 16:05:19

标签: python computer-vision pytorch

我有一个带有 RGB 图像和 Json 注释的对象检测数据集。我使用自定义 DataLoader 类来读取图像和标签。我面临的一个问题是,如果/当标签不包含某些对象时,我想在训练模型时跳过图像。

例如,如果一张图像不包含属于“汽车”类的任何目标标签,我想跳过它们。在解析我的 Json 注释时,我尝试检查不包含类“Cars”的标签并返回 None。随后,我使用 collat​​e 函数过滤 None 但不幸的是,它不起作用。

import torch
from torch.utils.data.dataset import Dataset
import json
import os
from PIL import Image
from torchvision import transforms
#import cv2
import numpy as np
general_classes = {
    # Cars
    "Toyota Corolla" : 0,
    "VW Golf" : 0,
    "VW Beetle" : 0,

    # Motor-cycles
    "Harley Davidson" : 1,
    "Yamaha YZF-R6" : 1,
}

car_classes={
"Toyota Corolla" : 0,
"VW Golf" : 0,
"VW Beetle" : 0
}

def get_transform(train):
    transforms = []
    # converts the image, a PIL image, into a PyTorch Tensor
    transforms.append(T.ToTensor())
    if train:
        # during training, randomly flip the training images
        # and ground-truth for data augmentation
        transforms.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(transforms)


def my_collate(batch):
    batch = list(filter(lambda x: x is not None, batch))
    return torch.utils.data.dataloader.default_collate(batch)


class FilteredDataset(Dataset):
    # The dataloader will skip the image and corresponding labels based on the dictionary 'car_classes'
    def __init__(self, data_dir, transforms):
        self.data_dir = data_dir
        img_folder_list = os.listdir(self.data_dir)
        self.transforms = transforms

        imgs_list = []
        json_list = []
        self.filter_count=0
        self.filtered_label_list=[]

        for img_path in img_folder_list:
            #img_full_path = self.data_dir + img_path
            img_full_path=os.path.join(self.data_dir,img_path)
            json_file = os.path.join(img_full_path, 'annotations-of-my-images.json')
            img_file = os.path.join(img_full_path, 'Image-Name.png')

            json_list.append(json_file)
            imgs_list.append(img_file)
        self.imgs = imgs_list
        self.annotations = json_list
        total_count=0

        for one_annotation in self.annotations:
            filtered_obj_id=[]
            with open(one_annotation) as f:
                img_annotations = json.load(f)

            parts_list = img_annotations['regions']
            for part in parts_list:
                current_obj_id = part['tags'][0] # bbox label 
                check_obj_id = general_classes[current_obj_id]
                if(check_obj_id==0):
                    subclass_id=car_classes[current_obj_id]
                    filtered_obj_id.append(subclass_id)
                    total_count=total_count+1

            if(len(filtered_obj_id)>0):
                self.filter_count=self.filter_count+1
                self.filtered_label_list.append(one_annotation)

        print("The total number of the objects in all images: ",total_count)


    # get one image and the bboxes,img_id, labels of parts, etc in the image as target.
    def __getitem__(self, idx):

        img_path = self.imgs[idx]
        image_id = torch.tensor([idx])
        
        with open(self.annotations[idx]) as f:
            img_annotations = json.load(f)
        parts_list = img_annotations['regions']
        obj_ids = []
        boxes = []
        for part in parts_list:
            obj_id = part['tags'][0]
            check_obj_id = general_classes[obj_id]
            if(check_obj_id==0):
               obj_id=car_classes[obj_id]
               obj_ids.append(obj_id)
                #print("---------------------------------------------------")
                
        if(len(obj_ids)>0):
            img = Image.open(img_path).convert("RGB")
            labels = torch.as_tensor(obj_ids, dtype = torch.int64)
            target = {}
            target['labels'] = labels
            
            if self.transforms is not None:
                img, target = self.transforms(img, target)
                return img, target
        else:
            return None


    def __len__(self):
        return len(self.filtered_label_list)




train_data_path = "path-to-my-annotation"
# Generators
train_dataset = FilteredDataset(train_data_path,get_transform(train=True))
print("Total files in the train_dataset: ",len(train_dataset))
#print("The first instance in the train dataset : ",train_dataset[0])
#training_generator = torch.utils.data.DataLoader(train_dataset)
training_generator = torch.utils.data.DataLoader(train_dataset,collate_fn=my_collate)
print("\n\n Iterator in action! ")
print("---------------------------------------------------------")
count=0
for img,target in training_generator:
    #print("The img name : ",img[0])
    count=count+1
    print("target name : ",target)
    print("count : ",count)
    print("**************************************************")

但是,我收到以下错误, enter image description here 任何人都可以提出一种跳过不包含特定分类标签的图像的方法吗?

0 个答案:

没有答案
相关问题