Cifar10数据集:从类中读取一定数量的图像

时间:2019-12-16 07:03:50

标签: python deep-learning dataset pytorch

我目前正在使用Pytorch学习深度学习,并使用Cifar 10数据集进行了一些实验。其中有10个班级,每个班级都有5000张测试图像。我只想使用60%的狗和鹿类数据以及100%的其他类数据。

根据我的理解,我需要使用自定义数据集。但我实际上无法弄清楚。如果您可以共享任何想法,示例代码或链接都将对我有所帮助。

1 个答案:

答案 0 :(得分:1)

您可以像这样使用Subset

from torchvision.datasets import CIFAR10
from torch.utils.data import Subset

ds = CIFAR10('~/.torch/data/', train=True, download=True)
dog_indices, deer_indices, other_indices = [], [], []
dog_idx, deer_idx = ds.class_to_idx['dog'], ds.class_to_idx['deer']

for i in range(len(ds)):
  current_class = ds[i][1]
  if current_class == dog_idx:
    dog_indices.append(i)
  elif current_class == deer_idx:
    deer_indices.append(i)
  else:
    other_indices.append(i)
dog_indices = dog_indices[:int(0.6 * len(dog_indices))]
deer_indices = deer_indices[:int(0.6 * len(deer_indices))]
new_dataset = Subset(ds, dog_indices+deer_indices+other_indices)