我正在使用PyTorch进行语义分割,但遇到了一个问题:我的数据同时包含图像及其对应的标签,而我想对它们执行数据增强,例如RandomHorizontalFlip和RandomCrop等。
这是我的代码,请检查并告知我如何在提供的代码中嵌入上述操作。
import torchvision.transforms.functional as F
class ToTensor(object):
    """Callable transform that converts a sample's image and label to tensors."""

    def __call__(self, sample):
        """Return a new dict whose 'image' and 'label' entries went through ``F.to_tensor``.

        ``sample`` must be subscriptable with the keys 'image' and 'label';
        any other entries are not carried over to the result.
        """
        return {key: F.to_tensor(sample[key]) for key in ('image', 'label')}
# Pipeline with a single step: convert each sample's image and label to tensors.
my_transform = transforms.Compose([ToTensor()])
dataset = Mydataset(
    image_dir,
    label_dir,
    transform=my_transform,
)
# Display the sample at index 1; the value it produced is listed below.
dataset[1]
{'image': tensor([[[0.0824, 0.0824, 0.0824, ..., 0.0824, 0.0824, 0.0824],
[0.0824, 0.0824, 0.0824, ..., 0.0824, 0.0824, 0.0824],
[0.0824, 0.0824, 0.0824, ..., 0.0824, 0.0824, 0.0824],
...,
[0.0902, 0.0902, 0.0902, ..., 0.0824, 0.0824, 0.0824],
[0.0824, 0.0824, 0.0824, ..., 0.0824, 0.0824, 0.0824],
[0.0745, 0.0745, 0.0745, ..., 0.0824, 0.0824, 0.0824]],
[[0.0824, 0.0824, 0.0824, ..., 0.0824, 0.0824, 0.0824],
[0.0824, 0.0824, 0.0824, ..., 0.0824, 0.0824, 0.0824],
[0.0824, 0.0824, 0.0824, ..., 0.0824, 0.0824, 0.0824],
...,
[0.0902, 0.0902, 0.0902, ..., 0.0824, 0.0824, 0.0824],
[0.0824, 0.0824, 0.0824, ..., 0.0824, 0.0824, 0.0824],
[0.0745, 0.0745, 0.0745, ..., 0.0824, 0.0824, 0.0824]],
[[0.0824, 0.0824, 0.0824, ..., 0.0824, 0.0824, 0.0824],
[0.0824, 0.0824, 0.0824, ..., 0.0824, 0.0824, 0.0824],
[0.0824, 0.0824, 0.0824, ..., 0.0824, 0.0824, 0.0824],
...,
[0.0902, 0.0902, 0.0902, ..., 0.0824, 0.0824, 0.0824],
[0.0824, 0.0824, 0.0824, ..., 0.0824, 0.0824, 0.0824],
[0.0745, 0.0745, 0.0745, ..., 0.0824, 0.0824, 0.0824]]]),
'label': tensor([[[0.0824, 0.0824, 0.0824, ..., 0.0824, 0.0824, 0.0824],
[0.0824, 0.0824, 0.0824, ..., 0.0824, 0.0824, 0.0824],
[0.0824, 0.0824, 0.0824, ..., 0.0824, 0.0824, 0.0824],
...,
[0.0902, 0.0902, 0.0902, ..., 0.0824, 0.0824, 0.0824],
[0.0824, 0.0824, 0.0824, ..., 0.0824, 0.0824, 0.0824],
[0.0745, 0.0745, 0.0745, ..., 0.0824, 0.0824, 0.0824]]])}
答案 0 :(得分:0)
您应该使用transforms.Compose()方法来组合不同的转换。例如:
import random  # stdlib; supplies the single random draw shared by image and label

# NOTE(review): assumes Mydataset hands `transform` a {'image', 'label'} dict of
# PIL images (or tensors), as the ToTensor class in the question implies -- confirm.
# The stock torchvision random transforms take ONE image and draw fresh random
# parameters on every call, so they can neither consume that dict nor keep the
# segmentation mask aligned with the image. Each wrapper below draws its random
# parameters once per sample and applies them to both entries.


class PairedRandomHorizontalFlip(object):
    """Horizontally flip image and label together with probability ``p``."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        # One coin toss for the whole sample keeps the mask aligned.
        if random.random() < self.p:
            image, label = F.hflip(image), F.hflip(label)
        return {'image': image, 'label': label}


class PairedRandomRotation(object):
    """Rotate image and label by the same angle drawn from [-degrees, degrees]."""

    def __init__(self, degrees):
        self.degrees = degrees

    def __call__(self, sample):
        angle = random.uniform(-self.degrees, self.degrees)
        # F.rotate defaults to nearest-neighbour resampling, so no new label
        # values are invented by interpolation.
        # NOTE(review): corners exposed by the rotation are filled with 0 --
        # confirm that 0 is an acceptable label value for those pixels.
        return {
            'image': F.rotate(sample['image'], angle),
            'label': F.rotate(sample['label'], angle),
        }


class PairedRandomCrop(object):
    """Cut the same randomly placed ``size`` window out of image and label."""

    def __init__(self, size):
        self.size = size

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        # Draw the window once, then reuse it for both entries.
        top, left, height, width = transforms.RandomCrop.get_params(
            image, output_size=self.size
        )
        return {
            'image': F.crop(image, top, left, height, width),
            'label': F.crop(label, top, left, height, width),
        }


class ImageOnlyColorJitter(object):
    """Apply ``transforms.ColorJitter`` to the image and pass the label through."""

    def __init__(self, **jitter_kwargs):
        self.jitter = transforms.ColorJitter(**jitter_kwargs)

    def __call__(self, sample):
        # Photometric changes must not touch the label: its pixel values
        # encode classes, not colours.
        return {'image': self.jitter(sample['image']), 'label': sample['label']}


my_transforms = transforms.Compose([
    PairedRandomHorizontalFlip(),
    # A single rotation step: chaining two random rotations resamples the
    # data twice for no benefit.
    PairedRandomRotation(20),
    PairedRandomCrop((512, 512)),
    ImageOnlyColorJitter(brightness=0.2, saturation=0.2, contrast=0.2),
    # The dict-aware ToTensor class from the question, not transforms.ToTensor,
    # which cannot handle the sample dict.
    ToTensor(),
])
dataset = Mydataset(image_dir, label_dir, transform=my_transforms)