我正在测试一个自定义的CNN模型,因此想先在一个很小的子集上训练,看看模型能否过拟合。我的问题是:怎样才能只取其中几个类(比如5个类)的数据来训练?
# Build the train and test splits, each with its own transform pipeline.
train_data = CUB200.CUB200(train=True, transform=transform_train)
test_data = CUB200.CUB200(train=False, transform=transform_test)

# Wrap both splits in shuffled loaders with batches of 32, loaded in the main
# process (num_workers=0). Only the training loader passes drop_last=True.
trainloader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=32,
    shuffle=True,
    num_workers=0,
    drop_last=True,
)
testloader = torch.utils.data.DataLoader(
    dataset=test_data,
    batch_size=32,
    shuffle=True,
    num_workers=0,
)
import os
import pandas as pd
from torchvision.datasets.folder import default_loader
from torch.utils.data import Dataset
base_folder = '../../data/CUB_200_2011/'


class CUB200(Dataset):
    """Dataset over the CUB_200_2011 folder layout.

    The constructor reads three space-separated index files from
    ``base_folder`` (``images.txt``, ``image_class_labels.txt`` and
    ``train_test_split.txt``), joins them on ``img_id`` and keeps the rows
    of the requested split.

    Args:
        base_folder: Directory holding the index files and the ``images``
            sub-directory.
        transform: Optional callable applied to each loaded image.
        train: If True keep rows with ``is_training_img == 1``, otherwise
            rows with ``is_training_img == 0``.
        num_classes: If given, keep only samples whose class id is at most
            ``num_classes``. Class ids start from 1 in the data, so the
            returned labels are then ``0 .. num_classes - 1`` (assuming the
            ids are consecutive). ``None`` (the default) keeps every class.

    Raises:
        ValueError: If ``num_classes`` is given and is smaller than 1.
    """

    def __init__(self, base_folder=base_folder, transform=None, train=True,
                 num_classes=None):
        if num_classes is not None and num_classes < 1:
            raise ValueError(
                'num_classes must be a positive integer or None, got %r'
                % (num_classes,))

        self.base_folder = base_folder
        self.transform = transform
        self.train = train
        self.num_classes = num_classes

        images = pd.read_csv(
            os.path.join(self.base_folder, 'images.txt'),
            sep=' ', names=['img_id', 'filepath'])
        image_class_labels = pd.read_csv(
            os.path.join(self.base_folder, 'image_class_labels.txt'),
            sep=' ', names=['img_id', 'target'])
        train_test_split = pd.read_csv(
            os.path.join(self.base_folder, 'train_test_split.txt'),
            sep=' ', names=['img_id', 'is_training_img'])

        # One row per image: img_id, filepath, target, is_training_img.
        data = images.merge(image_class_labels, on='img_id')
        data = data.merge(train_test_split, on='img_id')

        # Restrict to the first `num_classes` classes when requested, so a
        # small subset can be trained on without computing Subset indices.
        if num_classes is not None:
            data = data[data.target <= num_classes]

        split_flag = 1 if self.train else 0
        self.data = data[data.is_training_img == split_flag]

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.data)

    def __getitem__(self, index):
        """Load the sample at positional ``index``.

        Returns:
            A tuple ``(img, label)`` where ``img`` is the loaded (and, if a
            transform was given, transformed) image and ``label`` is the
            zero-based class id.
        """
        sample = self.data.iloc[index]
        # Use the instance's folder, not the module-level default, so a
        # custom base_folder passed to the constructor is honoured here too.
        path = os.path.join(self.base_folder, 'images', sample.filepath)
        img = default_loader(path)
        # Targets start from 1 in the data, so shift to 0.
        label = sample.target - 1
        if self.transform is not None:
            img = self.transform(img)
        return img, label
我知道在PyTorch中可以使用torch.utils.data.Subset来实现这一点。
但是,这样做就必须先计算出前5个类所有样本的索引。还有其他更便捷的方法吗?