我目前正在使用Pytorch学习深度学习,并使用Cifar 10
数据集进行了一些实验。其中有10个班级,每个班级都有5000张测试图像。我只想使用60%的狗和鹿类数据以及100%的其他类数据。
根据我的理解,我需要使用自定义数据集。但我实际上无法弄清楚。如果您可以共享任何想法,示例代码或链接都将对我有所帮助。
答案 0 :(得分:1)
您可以像这样使用Subset:
from torchvision.datasets import CIFAR10
from torch.utils.data import Subset
ds = CIFAR10('~/.torch/data/', train=True, download=True)
dog_indices, deer_indices, other_indices = [], [], []
dog_idx, deer_idx = ds.class_to_idx['dog'], ds.class_to_idx['deer']
for i in range(len(ds)):
current_class = ds[i][1]
if current_class == dog_idx:
dog_indices.append(i)
elif current_class == deer_idx:
deer_indices.append(i)
else:
other_indices.append(i)
dog_indices = dog_indices[:int(0.6 * len(dog_indices))]
deer_indices = deer_indices[:int(0.6 * len(deer_indices))]
new_dataset = Subset(ds, dog_indices+deer_indices+other_indices)