如何在PyTorch中拆分火车测试以对数据进行分类

时间:2020-06-08 16:30:53

标签: python-3.x split pytorch training-data dataloader

我有一个包含36406个图像,52个类别的图像数据集。它们是不平衡数据。我想拆分训练和测试数据,以便按数据集中类别的比例进行分配。 例如:

'Tshirts': 3534, 'Shirts': 3213, 'Casual Shoes': 2846,...

因此在火车数据集中,我希望每个类别的20%:

train dataset: 'Tshirts': 3534 * 0.2, 'Shirts': 3213*0.2, 'Casual Shoes': 2846*0.2,...

我不知道如何使用PyTorch来实现它。

我的数据集:

dataset = ImageFolder("/content/gdrive/My Drive/categorized_products"
                      , transform=transform)
classes = dataset.classes

我要使用:

train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=train_sampler, num_workers=16)
test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, sampler=test_sampler, num_workers=16)

0 个答案:

没有答案