嗨,我现在有一段数据加载代码,但不确定如何把数据拆分为训练数据和测试数据。谁能给我一些建议?下面是我的数据加载代码:
def __init__(self, root, specific_folder, img_extension, preprocessing_method=None,
             crop_size=(96, 112), train=True):
    """
    Build the image index for the LFW dataset.

    Only people listed in ``people.txt`` with at least 20 images are kept.

    root: path to the dataset to be used; ``people.txt`` is read from here.
    specific_folder: specific folder inside the same dataset.
    img_extension: extension of the dataset images (without the dot).
    preprocessing_method: string with the name of the preprocessing method.
    crop_size: retrieval network specific crop size.
    train: accepted but not used anywhere in this constructor.
    """
    self.preprocessing_method = preprocessing_method
    self.crop_size = crop_size
    self.model_align = None
    self.arr = []

    # people.txt: the first line is skipped; the remaining lines are read as
    # tab-separated "<person name>\t<number of images>" records.
    with open(os.path.join(root, 'people.txt')) as listing:
        records = listing.read().splitlines()[1:]

    image_paths = []
    names = []
    for record in records:
        fields = record.split('\t')
        if len(fields) <= 1:
            # line without an image count: nothing to index
            continue
        person, num_images = fields[0], int(fields[1])
        if num_images < 20:
            continue
        # image files are numbered from 1 with a zero-padded 4-digit suffix
        for number in range(1, num_images + 1):
            file_name = person + '_' + '{:04}'.format(number) + '.' + img_extension
            image_paths.append(os.path.join(root, specific_folder, person, file_name))
            names.append(person)

    self.imgl_list = image_paths
    self.people = names
    # integer class ids, one per image, derived from the person names
    self.classes = preprocessing.LabelEncoder().fit_transform(names)
    print(len(self.imgl_list), len(self.classes), len(self.people))
def __getitem__(self, index):
    """
    Load, preprocess and normalise the image stored at position ``index``.

    Returns a tuple ``(imgs, cl, imgl, bb, path, person)``:
    imgs: list of two float tensors, the preprocessed image and the same
        image reversed along its second axis, each normalised and with the
        last axis moved to the front.
    cl: encoded class label of the image.
    imgl: the preprocessed (not normalised) image returned by ``preprocess``.
    bb: second value returned by ``preprocess``.
    path: file path of the image.
    person: name of the person shown in the image.
    """
    path = self.imgl_list[index]
    image = imageio.imread(path)
    label = self.classes[index]

    # a 2-D array is a grayscale image: repeat it 3 times along a new last axis
    if image.ndim == 2:
        image = np.stack([image] * 3, 2)

    image, bb = preprocess(image, self.preprocessing_method, crop_size=self.crop_size,
                           is_processing_dataset=True, return_only_largest_bb=True,
                           execute_default=True)

    # pair the image with its reverse along the second axis, then normalise
    # each view and move the last axis to the front
    reversed_image = image[:, ::-1, :]
    imgs = [
        torch.from_numpy(((view - 127.5) / 128.0).transpose(2, 0, 1)).float()
        for view in (image, reversed_image)
    ]
    return imgs, label, image, bb, path, self.people[index]
def __len__(self):
    """Return the number of indexed image paths."""
    total = len(self.imgl_list)
    return total
我需要把其中的数据按 20% 和 80% 的比例拆分,这样才能测试我的模块。我已经折腾了快一个星期,仍然完全不知道该怎么做。如果有人能提供帮助,我将不胜感激。
答案 0 :(得分:0)
使用 PyTorch 时,通常可以这样做:
import torch
import numpy as np
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler
# Split one dataset into train/test subsets by index and wrap each subset in
# its own DataLoader. Both loaders share the same underlying dataset object;
# the samplers decide which indices each loader is allowed to draw.
dataset = yourdatahere  # placeholder: replace with your Dataset instance
batch_size = 16  # change to whatever you'd like it to be
test_split = .2  # fraction of the samples reserved for testing
shuffle_dataset = True
random_seed = 42

# Creating data indices for training and test splits:
dataset_size = len(dataset)
indices = list(range(dataset_size))
split = int(np.floor(test_split * dataset_size))
if shuffle_dataset:
    # fixed seed so the same split is produced on every run
    np.random.seed(random_seed)
    np.random.shuffle(indices)
# the first `split` indices form the test set, the remainder the training set
train_indices, test_indices = indices[split:], indices[:split]

# Creating PT data samplers and loaders:
train_sampler = SubsetRandomSampler(train_indices)
test_sampler = SubsetRandomSampler(test_indices)
train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                           sampler=train_sampler)
test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                          sampler=test_sampler)

# Usage Example:
num_epochs = 10
for epoch in range(num_epochs):
    # Train:
    # NOTE: the (faces, labels) unpacking must match what the dataset's
    # __getitem__ returns; adapt it if your dataset yields more items.
    for batch_index, (faces, labels) in enumerate(train_loader):
        # ...
        pass  # placeholder so the loop is valid Python; put the training step here
请注意,您还应该把训练数据进一步划分为训练集和验证集,可以沿用上面相同的逻辑。