在PyTorch中,我试图编写一个类,该类可以使用data
和label
之类的语法分别返回整个dataset.data
和dataset.label
。代码框架如下:
class MyDataset(object):
data = _get_data()
label = _get_label()
def __init__(self, dir, transforms):
self.img_list = ... # all image paths loaded from dir
# do something
def __getitem__(self):
# do something
return data, label
def __len__(self):
return len(self.img_list)
def _get_data():
# do something
def _get_label():
# do something
但是,当我使用dataset.data
和dataset.label
访问相应的变量时,什么也不返回。
我想知道为什么会这样,如何解决这个问题。
感谢您的关注。
我自己解决了这个问题。该解决方案非常简单,只需利用类变量的属性即可。
class FaceDataset(object):
# class variable
data = None
label = None
def __init__(self, root, transforms=None):
# read img_list from root
img_list = ...
self.transforms = ...
FaceDataset.data = FaceDataset._get_data(self.img_list, self.transforms)
FaceDataset.label = FaceDataset._get_label(self.img_list)
@classmethod
def _get_data(cls, img_list, transforms):
data_list = []
for img_path in img_list:
data_list.append(transforms(Image.open(img_path)).unsqueeze(0))
return torch.stack(data_list, dim=0)
@classmethod
def _get_label(cls, img_list):
label = torch.zeros(len(img_list))
for i, img_path in enumerate(img_list):
label[i] = ...
return label
def __getitem__(self, index):
img_path = self.img_list[index]
label = ...
# read image from file
data = Image.open(img_path)
# apply transform defined in __init__
data = self.transforms(data)
return data, label
def __len__(self):
return len(self.img_list)
答案 0 :(得分:1)
在Python上here已回答了在Python中创建自定义数据集的“常规”方法。恰好有一个官方的PyTorch tutorial。
作为一个简单的示例,您可以阅读PyTorch MNIST数据集代码here(此数据集在此PyTorch example code中使用以作进一步说明)。最后,您可以在该Torchvision数据集list中找到其他数据集实现(单击数据集名称,然后单击数据集文档中的“源”按钮,以访问数据集的PyTorch实现)。