如何开始在PyTorch中训练对象检测模型

时间:2020-02-23 05:59:03

标签: python pytorch object-detection torch torchvision

我想在PyTorch中训练自定义的目标检测模型。我正在编写一个CustomDataset类来加载我的数据集。我创建数据集的代码如下:

class CustomDataset(torch.utils.data.Dataset):
    """Pascal-VOC-style object detection dataset.

    ``root_dir`` must contain an ``images/`` folder and an ``annotations/``
    folder holding one XML file per image.  Images and annotations are paired
    by position in their sorted directory listings, so both folders must sort
    into the same order.

    Each item is an ``(image, target)`` pair in the format torchvision
    detection models consume:

    * ``image``: float32 tensor of shape ``(3, H, W)`` scaled to ``[0, 1]``.
    * ``target``: dict with ``boxes`` (float32, ``(N, 4)`` as
      ``[xmin, ymin, xmax, ymax]``), ``labels`` (int64, ``(N,)``) and
      ``seg_areas`` (float32, ``(N,)``).

    Args:
        root_dir: Directory containing ``images/`` and ``annotations/``.
        transform: Optional callable ``transform(image, target)`` returning
            the transformed ``(image, target)`` pair.  It receives both so
            that geometric augmentations can keep the boxes in sync.
    """

    def __init__(self, root_dir, transform=None):
        self.root = root_dir
        self.transform = transform
        self.imgs = sorted(os.listdir(os.path.join(root_dir, "images/")))
        self.annotations = sorted(os.listdir(os.path.join(root_dir, "annotations/")))

        self._classes = ('__background__',  # always index 0
                         'car', 'person', 'bicycle', 'dog', 'other')
        self.num_classes = len(self._classes)

        # Labels must be contiguous integers in [1, num_classes - 1] so they
        # index the classifier head's outputs; 0 is reserved for background.
        self._class_to_ind = {
            name: ind for ind, name in enumerate(self._classes) if ind != 0
        }

    def __len__(self) -> int:
        """Return the number of images in the dataset."""
        return len(self.imgs)

    def __getitem__(self, idx):
        """Return the ``(image, target)`` pair at position ``idx``."""
        img = self._load_image(idx)
        target = self._load_target(idx)

        if self.transform is not None:
            img, target = self.transform(img, target)

        return img, target

    def _load_image(self, idx):
        """Read image ``idx`` as a float32 ``(3, H, W)`` tensor in ``[0, 1]``."""
        path = os.path.join(self.root, "images/", self.imgs[idx])
        # Force three channels so grayscale/RGBA files produce the same layout.
        with Image.open(path) as pil_img:
            pixels = np.array(pil_img.convert("RGB"))
        # PIL yields HWC uint8; detection models expect CHW float in [0, 1].
        return torch.from_numpy(pixels).permute(2, 0, 1).float() / 255.0

    def _load_target(self, idx):
        """Parse annotation ``idx`` into the target dict of tensors.

        Raises:
            KeyError: If an object's class name is not in ``_class_to_ind``.
        """
        filename = os.path.join(self.root, 'annotations', self.annotations[idx])
        objs = ET.parse(filename).findall('object')

        boxes = []
        labels = []
        seg_areas = []
        for obj in objs:
            bbox = obj.find('bndbox')
            x1 = float(bbox.find('xmin').text)
            y1 = float(bbox.find('ymin').text)
            x2 = float(bbox.find('xmax').text)
            y2 = float(bbox.find('ymax').text)

            labels.append(self._class_to_ind[obj.find('name').text.lower().strip()])
            boxes.append([x1, y1, x2, y2])
            # Inclusive pixel coordinates, hence the +1 on each side.
            seg_areas.append((x2 - x1 + 1) * (y2 - y1 + 1))

        return {
            # Explicit reshape keeps the (N, 4) shape when an image has no objects.
            'boxes': torch.as_tensor(boxes, dtype=torch.float32).reshape(len(boxes), 4),
            # Class indices must be int64 for the detection losses.
            'labels': torch.as_tensor(labels, dtype=torch.int64),
            'seg_areas': torch.as_tensor(seg_areas, dtype=torch.float32),
        }

我用来开始训练的主函数如下:

# Fine-tune a COCO-pretrained Faster R-CNN on the custom dataset.
num_classes = 6  # 5 object classes + background
model = fasterrcnn_resnet50_fpn(pretrained=True)
# Swap the COCO box-classification head for one sized to our label set.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

dataset_train = CustomDataset('images/train/')
dataset_val = CustomDataset('images/val/')


def collate_fn(batch):
    """Turn a list of (image, target) samples into (images, targets) tuples.

    The default collate function would try to stack the samples into single
    tensors, which fails for images of different sizes and for targets with
    different numbers of boxes.  Detection models take lists instead.
    """
    return tuple(zip(*batch))


data_loader_train = torch.utils.data.DataLoader(
    dataset_train, batch_size=1, shuffle=True, collate_fn=collate_fn)

data_loader_test = torch.utils.data.DataLoader(
    dataset_val, batch_size=1, shuffle=False, collate_fn=collate_fn)

# Use the GPU when present, but still run (slowly) on CPU-only machines.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005,
                            momentum=0.9, weight_decay=0.0005)

# Divide the learning rate by 10 every 3 epochs.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                               step_size=3,
                                               gamma=0.1)

num_epochs = 10

for epoch in range(num_epochs):
    # In training mode the model takes (images, targets) and returns a dict
    # of losses; in eval mode it takes only images and returns predictions.
    model.train()
    epoch_loss = 0.0

    for images, targets in data_loader_train:
        # The model expects a list of float (C, H, W) image tensors in [0, 1]
        # and a list of dicts with float 'boxes' (N, 4) and int64 'labels' (N,).
        images = [image.to(device) for image in images]
        targets = [{key: value.to(device) for key, value in target.items()}
                   for target in targets]

        loss_dict = model(images, targets)
        # Total loss = classification + box regression + RPN losses.
        losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        epoch_loss += losses.item()

    # Step the schedule once per epoch, after the optimizer updates.
    lr_scheduler.step()
    print(f"epoch {epoch + 1}/{num_epochs}: "
          f"mean loss {epoch_loss / max(len(data_loader_train), 1):.4f}")

我已经定义了模型,并分别为训练集和验证集创建了DataLoader,但是我不确定该如何开始训练,也不确定应该以怎样的格式把输入提供给模型。
我是PyTorch的初学者,如果有人能帮助我,那就太好了。

0 个答案:

没有答案