如何使用Pytorch将MNIST数据集拆分为分布式节点的多个子集?

时间:2019-07-26 10:34:08

标签: pytorch distributed mnist dataloader

我正在为torchvision.datasets.MNIST同时在3个分布式节点上运行的简单CNN实施DistributedDataParallel培训。我想将数据集划分为3个不重叠的子集(A,B,C),每个子集应包含20000张图像。各个子集应进一步分为培训和测试分区,即0.7%的培训和0.3%的测试。我计划将每个子集分别提供给每个分布式节点,以便它们可以以DistributedDataParallel方式进行训练和测试。

如下所示的基本代码,从torchvision.datasets.MNIST下载MNIST数据集,然后使用torch.utils.data.distributed.DistributedSampler和torch.utils.data.DataLoader创建数据批次,以便在单个数据上进行训练和测试节点。


# TRAINING DATA

train_dataset = datasets.MNIST('data', train=True, download=True, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))]))
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas=3, rank=dist.get_rank())
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=False, num_workers=3, pin_memory=True, sampler=True)


# TESTING DATA

test_dataset = datasets.MNIST('data', train=False, download=False, transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]))
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=3, pin_memory=True)

我希望答案会创建train_dataset_a,train_dataset_b和train_dataset_c以及test_dataset_a,test_dataset_b和test_dataset_c。

0 个答案:

没有答案
相关问题