因此,为了简化起见,我具有以下代码行可从名为“ 0”和“ 1”的两个类加载图像数据集:
train_data = torchvision.datasets.ImageFolder(os.path.join(TRAIN_DATA_DIR), train_transform)
然后以这种方式准备要与模型一起使用的加载器:
train_loader = torch.utils.data.DataLoader(train_data, TRAIN_BATCH_SIZE, shuffle=True)
因此,现在每个图像都与一个类相关联,我想要做的是获取每个图像,并在这两行代码之间对其进行转换,比如说旋转四个角度之一:0、90, 180、270,并将该信息添加为四个类别的附加标签:0、1、2、3。最后,我希望数据集包含旋转的图像,并将两个值的列表作为它们的标签:类别图片和应用的旋转。
我尝试过此操作,没有错误,但是如果我尝试打印标签,则数据集保持不变:
for idx,label in enumerate(train_data.targets):
train_data.targets[idx] = [label, 1]
是否有一种不错的方法,可以直接修改train_data
而不需要自定义数据集?
答案 0 :(得分:0)
是否有一种很好的方法,可以直接修改train_data而无需自定义数据集?
不,没有。如果要使用datasets.ImageFolder
,则必须接受其有限的灵活性。实际上,ImageFolder
只是DatasetFolder
的子类,这几乎就是自定义数据集。您可以在其source code中看到__getItem__
的以下部分:
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
这使得您想要的一切变得不可能,因为您期望的变换应该同时修改图像和目标,这在此处是独立完成的。
因此,首先使Dataset
的子类类似于DatasetFolder
,然后简单地实现自己的变换,该变换将同时获取图像和目标并返回其变换后的值。这只是您可能拥有的转换类的示例,然后需要将其组合成一个函数调用:
class RotateTransform(object):
def __call__(self, image, target):
# Rotate the image randomly and adjust the target accordingly
#...
return image, target
如果这对您来说太麻烦了,那么最好的选择就是提到@jchaykow,它是在运行代码之前简单地修改文件。