size = len(CIFAR10_training)
dataset_indices = list(range(size))
val_index = int(np.floor(0.9 * size))
train_idx, val_idx = dataset_indices[:val_index], dataset_indices[val_index:]
train_sampler = SubsetRandomSampler(train_idx)
val_sampler = SubsetRandomSampler(val_idx)
train_dataloader = torch.utils.data.DataLoader(CIFAR10_training,
batch_size=config['batch_size'],
shuffle=False, sampler = train_sampler)
valid_dataloader = torch.utils.data.DataLoader(CIFAR10_training,
batch_size=config['batch_size'],
shuffle=False, sampler = val_sampler)
print(len(train_dataloader.dataset),len(valid_dataloader.dataset),
,但最后一个打印语句将打印50000和10000。如果不是45000和5000 当我打印train_idx和val_idx时,它会打印正确的值([0:44999],[45000:49999] 我的代码有什么问题吗
答案 0 :(得分:0)
我无法复制您的结果,当我执行代码时,打印语句输出的输出值是相同值的两倍:train_CIFAR10
中的元素数量。因此,我想您在复制代码时犯了一个错误,实际上为valid_dataloader
(或类似的东西)提供了CIFAR10_test
作为参数。在下文中,我将假设是这种情况,并且您的打印输出为(50000, 50000)
,这是Pytorch CIFAR10数据集训练部分的大小。
然后完全可以预期,并且不应该输出(45000,5000)。您要输入train_dataloader.dataset
和valid_dataloader.dataset
的长度,即基础数据集的长度。对于您的两个加载器,此数据集均为CIFAR10_training
。因此,您将获得此数据集大小的两倍(即50000)。
您也不能要求len(train_dataloader)
,因为这样会产生数据集中的批次数量(大约45000/batch_size
)。
如果您需要知道分割的大小,则必须计算采样器的长度:
print(len(train_dataloader.sampler), len(valid_dataloader.sampler))
除此之外,您的代码还不错,您正在正确分割数据。