How do I extract only a subset of classes from torchvision.datasets.CIFAR10?

Asked: 2019-01-26 16:08:28

Tags: pytorch

How can I extract only 2 or 3 classes from torchvision.datasets.CIFAR10?

The standard way to load all 10 classes:

import torch
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)

1 answer:

Answer 0 (score: 1)

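If you look at the source code of CIFAR10, you can see that the data is stored as a numpy array and the labels as a list. You can therefore subclass it and filter both with the same mask. Here is an example: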

import numpy as np
import torchvision


class SubLoader(torchvision.datasets.CIFAR10):
    """CIFAR10 with the classes listed in exclude_list removed."""

    def __init__(self, *args, exclude_list=None, **kwargs):
        super(SubLoader, self).__init__(*args, **kwargs)

        # None (rather than a mutable [] default) means "exclude nothing".
        if not exclude_list:
            return

        # Recent torchvision releases store the images in self.data and the
        # labels in self.targets for both splits. Older releases used
        # train_data / train_labels and test_data / test_labels instead.
        labels = np.array(self.targets)
        exclude = np.array(exclude_list).reshape(1, -1)
        # Compare every label with every excluded class; keep rows with no match.
        mask = ~(labels.reshape(-1, 1) == exclude).any(axis=1)

        self.data = self.data[mask]
        self.targets = labels[mask].tolist()
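
Since the question asks for keeping 2 or 3 classes while SubLoader takes a list of classes to exclude, the exclude list is just the complement of the classes you want. Below is a minimal sketch of that step, plus a pure-Python equivalent of the filtering on toy data. The helper names classes_to_exclude and filter_by_label are hypothetical and not part of torchvision; the toy lists merely stand in for the image array and the label list.

```python
def classes_to_exclude(keep, num_classes=10):
    """Return every label NOT in `keep` (hypothetical helper)."""
    keep = set(keep)
    return [c for c in range(num_classes) if c not in keep]


def filter_by_label(samples, labels, exclude_list):
    """Drop each (sample, label) pair whose label is excluded.

    Pure-Python stand-in for the boolean mask used in SubLoader.
    """
    excluded = set(exclude_list)
    kept = [(s, y) for s, y in zip(samples, labels) if y not in excluded]
    kept_samples = [s for s, _ in kept]
    kept_labels = [y for _, y in kept]
    return kept_samples, kept_labels


# Toy stand-ins for the dataset's image array and label list.
toy_samples = ["img0", "img1", "img2", "img3", "img4"]
toy_labels = [0, 3, 5, 3, 9]

# Keep only classes 3 and 5, so everything else is excluded.
exclude = classes_to_exclude(keep=[3, 5])
samples, labels = filter_by_label(toy_samples, toy_labels, exclude)

print(exclude)
print(samples, labels)
```

With the class from the answer, the resulting list would be passed as SubLoader(root='./data', train=True, download=True, transform=transform, exclude_list=exclude). Note that the surviving labels keep their original values (3 and 5 here), so if you train a classifier with only as many outputs as kept classes, you will probably need to remap them to 0, 1, ... first.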