CIFAR10数据加载器采样器拆分

时间:2020-10-15 20:40:06

标签: python numpy machine-learning pytorch

我正在尝试拆分CIFAR10的训练数据,因此将训练集中的最后5000个用于验证。我的代码

size = len(CIFAR10_training)
dataset_indices = list(range(size))
val_index = int(np.floor(0.9 * size))
train_idx, val_idx = dataset_indices[:val_index], dataset_indices[val_index:]
train_sampler = SubsetRandomSampler(train_idx)
val_sampler = SubsetRandomSampler(val_idx)

train_dataloader = torch.utils.data.DataLoader(CIFAR10_training,
                                          batch_size=config['batch_size'],
                                          shuffle=False,  sampler = train_sampler)
valid_dataloader = torch.utils.data.DataLoader(CIFAR10_training,
                                           batch_size=config['batch_size'],
                                           shuffle=False,  sampler = val_sampler)
print(len(train_dataloader.dataset),len(valid_dataloader.dataset),

,但最后一个打印语句将打印50000和10000。如果不是45000和5000 当我打印train_idx和val_idx时,它会打印正确的值([0:44999],[45000:49999] 我的代码有什么问题吗

1 个答案:

答案 0 :(得分:0)

我无法复制您的结果,当我执行代码时,打印语句输出的输出值是相同值的两倍:train_CIFAR10中的元素数量。因此,我想您在复制代码时犯了一个错误,实际上为valid_dataloader(或类似的东西)提供了CIFAR10_test作为参数。在下文中,我将假设是这种情况,并且您的打印输出为(50000, 50000),这是Pytorch CIFAR10数据集训练部分的大小。

然后完全可以预期,并且不应该输出(45000,5000)。您要输入train_dataloader.datasetvalid_dataloader.dataset的长度,即基础数据集的长度。对于您的两个加载器,此数据集均为CIFAR10_training。因此,您将获得此数据集大小的两倍(即50000)。

您也不能要求len(train_dataloader),因为这样会产生数据集中的批次数量(大约45000/batch_size)。

如果您需要知道分割的大小,则必须计算采样器的长度:

print(len(train_dataloader.sampler), len(valid_dataloader.sampler))

除此之外,您的代码还不错,您正在正确分割数据。