PyTorch custom RandomCrop for semantic segmentation

Time: 2020-09-29 13:09:14

Tags: pytorch torchvision

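I am trying to implement a custom dataset loader. First, I resize the image and the label by the same ratio, chosen between (0.98,1,1). Then I randomly crop the image and the label with the same parameters so that they can be fed to the NN. However, I get an error from a PyTorch function. Here is my code: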

# Imports were not shown in the original post. Per the answer below,
# F was bound to torch.nn.functional, which is what triggers the error.
import random

import torch.nn.functional as F
from torchvision import transforms


class RandomCrop(object):

    def __init__(self, size, padding=None, pad_if_needed=True, fill=0, padding_mode='constant'):
        self.size = size  # (height, width) of the crop
        self.padding = padding
        self.pad_if_needed = pad_if_needed
        self.fill = fill
        self.padding_mode = padding_mode

    @staticmethod
    def get_params(img, output_size):
        w, h = img.size  # PIL reports size as (width, height)
        th, tw = output_size
        if w == tw and h == th:
            return 0, 0, h, w

        i = random.randint(0, h - th)  # top edge of the crop
        j = random.randint(0, w - tw)  # left edge of the crop
        return i, j, th, tw

    def __call__(self, data):
        img, mask = data["image"], data["mask"]

        # pad the width if needed
        if self.pad_if_needed and img.size[0] < self.size[1]:
            img = F.pad(img, (self.size[1] - img.size[0], 0), self.fill, self.padding_mode)
            mask = F.pad(mask, (self.size[1] - mask.size[0], 0), self.fill, self.padding_mode)
        # pad the height if needed
        if self.pad_if_needed and img.size[1] < self.size[0]:
            img = F.pad(img, (0, self.size[0] - img.size[1]), self.fill, self.padding_mode)
            mask = F.pad(mask, (0, self.size[0] - mask.size[1]), self.fill, self.padding_mode)

        # use the same crop window for the image and the mask
        i, j, h, w = self.get_params(img, self.size)
        crop_image = transforms.functional.crop(img, i, j, h, w)
        crop_mask = transforms.functional.crop(mask, i, j, h, w)

        return {"image": crop_image, "mask": crop_mask}

This is the error:

AttributeError: 'Image' object has no attribute 'dim'

1 answer:

Answer 0 (score: 2)

I had mistakenly imported nn.functional.pad instead of transforms.functional.pad. After changing the import, everything worked.