Pytorch / torchvision-修改数据集对象的图像和标签

时间:2019-10-18 14:13:09

标签: image-processing dataset label pytorch torchvision

因此,为了简化起见,我具有以下代码行可从名为“ 0”和“ 1”的两个类加载图像数据集:

train_data = torchvision.datasets.ImageFolder(os.path.join(TRAIN_DATA_DIR), train_transform)

然后以这种方式准备要与模型一起使用的加载器:

train_loader = torch.utils.data.DataLoader(train_data, TRAIN_BATCH_SIZE, shuffle=True)

因此,现在每个图像都与一个类相关联,我想要做的是获取每个图像,并在这两行代码之间对其进行转换,比如说旋转四个角度之一:0、90, 180、270,并将该信息添加为四个类别的附加标签:0、1、2、3。最后,我希望数据集包含旋转的图像,并将两个值的列表作为它们的标签:类别图片和应用的旋转。

我尝试过此操作,没有错误,但是如果我尝试打印标签,则数据集保持不变:

for idx,label in enumerate(train_data.targets):
    train_data.targets[idx] = [label, 1]

是否有一种不错的方法,可以直接修改train_data而不需要自定义数据集?

1 个答案:

答案 0 :(得分:0)

  

是否有一种很好的方法,可以直接修改train_data而无需自定义数据集?

不,没有。如果要使用datasets.ImageFolder,则必须接受其有限的灵活性。实际上,ImageFolder只是DatasetFolder的子类,这几乎就是自定义数据集。您可以在其source code中看到__getItem__的以下部分:

if self.transform is not None:
    sample = self.transform(sample)
if self.target_transform is not None:
    target = self.target_transform(target)

这使得您想要的一切变得不可能,因为您期望的变换应该同时修改图像和目标,这在此处是独立完成的。

因此,首先使Dataset的子类类似于DatasetFolder,然后简单地实现自己的变换,该变换将同时获取图像和目标并返回其变换后的值。这只是您可能拥有的转换类的示例,然后需要将其组合成一个函数调用:

class RotateTransform(object):
    def __call__(self, image, target):
        # Rotate the image randomly and adjust the target accordingly
        #...

        return image, target

如果这对您来说太麻烦了,那么最好的选择就是提到@jchaykow,它是在运行代码之前简单地修改文件。