How can I get a small subset of an image dataset containing only a few classes?

Asked: 2020-05-18 08:19:36

Tags: python pytorch

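I am trying to test a custom CNN model, so I want to train it on a very small subset and check whether it can overfit. My question is: how can I get only a few classes (say 5) for training?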

  • This is my data loader:
train_data = CUB200.CUB200(transform=transform_train, train=True)
test_data = CUB200.CUB200(transform=transform_test, train=False)

trainloader = torch.utils.data.DataLoader(dataset=train_data, batch_size=32, shuffle=True, num_workers=0, drop_last=True)
testloader = torch.utils.data.DataLoader(dataset=test_data, batch_size=32, shuffle=True, num_workers=0)
  • Custom dataset:
import os
import pandas as pd
from torchvision.datasets.folder import default_loader
from torch.utils.data import Dataset


base_folder = '../../data/CUB_200_2011/'

class CUB200(Dataset):
    def __init__(self, base_folder=base_folder, transform=None, train=True):
        self.base_folder = base_folder
        self.transform = transform
        self.train = train
        images = pd.read_csv(os.path.join(base_folder, 'images.txt'), sep=' ',
                        names=['img_id', 'filepath'])
        image_class_labels = pd.read_csv(os.path.join(base_folder, 'image_class_labels.txt'),
                                    sep=' ', names=['img_id', 'target'])
        train_test_split = pd.read_csv(os.path.join(base_folder, 'train_test_split.txt'),
                                        sep=' ', names=['img_id', 'is_training_img'])
        data = images.merge(image_class_labels, on='img_id')
        data = data.merge(train_test_split, on='img_id')
        if self.train:
            self.data = data[data.is_training_img == 1]
        else:
            self.data = data[data.is_training_img == 0]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data.iloc[index]
        # Use the instance's folder so a non-default base_folder argument is honored
        path = os.path.join(self.base_folder, 'images', sample.filepath)
        img = default_loader(path)
        # Targets start from 1 in the data, so shift them to start from 0
        label = sample.target - 1
        if self.transform is not None:
            img = self.transform(img)
        return img, label
  • Image folder:

[screenshot of the image folder, not available]

I know this can be done in PyTorch with torch.utils.data.Subset. However, if I do that, I have to compute the indices of the first 5 classes myself. Is there a more convenient way?
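For reference, the index computation that Subset needs can be sketched in a few lines. This is only a minimal sketch: the helper name first_k_class_indices and the toy label list are hypothetical, and the commented-out usage assumes that train_data.data is the filtered DataFrame built in CUB200.__init__ above, so that positions in it line up with what __getitem__ reads via iloc.

```python
def first_k_class_indices(targets, num_classes=5):
    """Return the positions of samples whose 0-based label is below num_classes."""
    return [i for i, t in enumerate(targets) if t < num_classes]


# Toy 0-based labels standing in for the real dataset's labels (hypothetical data)
toy_targets = [0, 7, 3, 199, 4, 5, 1]
subset_indices = first_k_class_indices(toy_targets, num_classes=5)
print(subset_indices)

# Intended use with the dataset from the question (not run here):
# from torch.utils.data import Subset, DataLoader
# targets = (train_data.data.target - 1).tolist()  # shift labels to start at 0
# small_train = Subset(train_data, first_k_class_indices(targets, 5))
# trainloader = DataLoader(small_train, batch_size=32, shuffle=True)
```

Because Subset indexes by position, the list must hold positions within the already filtered training frame, not the original img_id values. Another option that avoids Subset altogether would be to filter self.data inside __init__, for example with data.target.isin(...), though that requires changing the dataset class.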

0 answers:

No answers yet