How do I extract only 2 or 3 classes from torchvision.datasets.CIFAR10?
This is the standard way to load all 10 classes:
import torch
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)
Answer 0 (score: 1)
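If you look at the source code of CIFAR10, you can see that the data is stored as a numpy array and the labels as a list. You can therefore subclass it and filter both with the same mask. Here is an example: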
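import numpy as np
import torchvision


class SubLoader(torchvision.datasets.CIFAR10):
    """CIFAR10 with every sample whose label is in exclude_list removed."""

    def __init__(self, *args, exclude_list=None, **kwargs):
        super(SubLoader, self).__init__(*args, **kwargs)

        # Nothing to exclude: keep the full dataset.
        if not exclude_list:
            return

        # train_data/train_labels and test_data/test_labels are the attribute
        # names used by older torchvision releases.
        if self.train:
            labels = np.array(self.train_labels)
            exclude = np.array(exclude_list).reshape(1, -1)
            # True for every sample whose label is not in exclude_list.
            mask = ~(labels.reshape(-1, 1) == exclude).any(axis=1)
            self.train_data = self.train_data[mask]
            self.train_labels = labels[mask].tolist()
        else:
            labels = np.array(self.test_labels)
            exclude = np.array(exclude_list).reshape(1, -1)
            mask = ~(labels.reshape(-1, 1) == exclude).any(axis=1)
            self.test_data = self.test_data[mask]
            self.test_labels = labels[mask].tolist()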
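
The class above reads train_data/train_labels and test_data/test_labels. Newer torchvision releases, as far as I know, store both splits under data and targets instead, so the same masking idea can be written once as a helper. The following is only a sketch under that assumption: filter_classes and load_cifar10_subset are hypothetical names introduced here, not part of torchvision, and np.isin replaces the broadcasted comparison from the answer.

```python
import numpy as np


def filter_classes(data, targets, exclude_list):
    """Drop every sample whose label appears in exclude_list.

    data is a numpy array with one sample per row and targets is a sequence
    of integer labels of the same length. Returns (filtered_data, filtered_targets).
    """
    labels = np.asarray(targets)
    # True for every sample whose label is NOT in exclude_list.
    mask = ~np.isin(labels, list(exclude_list))
    return data[mask], labels[mask].tolist()


def load_cifar10_subset(exclude_list, root='./data', train=True, transform=None):
    """Hypothetical wrapper: load CIFAR10 and remove the excluded classes.

    Assumes the dataset object exposes `data` (numpy array) and `targets`
    (list), which is an assumption about the installed torchvision version.
    """
    # Imported lazily so filter_classes can be used without torchvision.
    import torchvision

    dataset = torchvision.datasets.CIFAR10(
        root=root, train=train, download=True, transform=transform)
    dataset.data, dataset.targets = filter_classes(
        dataset.data, dataset.targets, exclude_list)
    return dataset
```

For example, load_cifar10_subset(exclude_list=[2, 3, 4, 5, 6, 7, 8, 9]) would keep only classes 0 and 1. The kept labels retain their original values, so if the remaining classes are not already 0..k-1 you may need to remap them before training a k-way classifier.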