在PyTorch中创建自定义数据集

时间:2019-10-22 03:00:21

标签: python class object pytorch

问题

在PyTorch中,我试图编写一个类,该类可以使用datalabel之类的语法分别返回整个dataset.datadataset.label。代码框架如下:

class MyDataset(object):
  data = _get_data()
  label = _get_label()
  def __init__(self, dir, transforms):
    self.img_list = ... # all image paths loaded from dir
    # do something 

  def __getitem__(self):
    # do something
    return data, label

  def __len__(self):
    return len(self.img_list)

  def _get_data():
    # do something

  def _get_label():
    # do something

但是,当我使用dataset.datadataset.label访问相应的变量时,什么也不返回。

我想知道为什么会这样,如何解决这个问题。

编辑

感谢您的关注。

我自己解决了这个问题。该解决方案非常简单,只需利用类变量的属性即可。

class FaceDataset(object):
    # class variable
    data = None
    label = None

    def __init__(self, root, transforms=None):
        # read img_list from root
        img_list = ...
        self.transforms = ...
        FaceDataset.data = FaceDataset._get_data(self.img_list, self.transforms)
        FaceDataset.label = FaceDataset._get_label(self.img_list)

    @classmethod
    def _get_data(cls, img_list, transforms):
        data_list = []
        for img_path in img_list:
            data_list.append(transforms(Image.open(img_path)).unsqueeze(0))
        return torch.stack(data_list, dim=0)

    @classmethod
    def _get_label(cls, img_list):
        label = torch.zeros(len(img_list))
        for i, img_path in enumerate(img_list):
            label[i] = ...
        return label

    def __getitem__(self, index):
        img_path = self.img_list[index]
        label = ...

        # read image from file
        data = Image.open(img_path)
        # apply transform defined in __init__
        data = self.transforms(data)

        return data, label

    def __len__(self):
        return len(self.img_list)

1 个答案:

答案 0 :(得分:1)

在Python上here已回答了在Python中创建自定义数据集的“常规”方法。恰好有一个官方的PyTorch tutorial

作为一个简单的示例,您可以阅读PyTorch MNIST数据集代码here(此数据集在此PyTorch example code中使用以作进一步说明)。最后,您可以在该Torchvision数据集list中找到其他数据集实现(单击数据集名称,然后单击数据集文档中的“源”按钮,以访问数据集的PyTorch实现)。