PyTorch数据集:将整个数据集转换为NumPy

时间:2019-02-27 03:38:16

标签: python numpy pytorch torchvision

我正在尝试将Torchvision MNIST训练和数据集转换为NumPy数组,但是找不到实际执行转换的文档。

我的目标是获取整个数据集并将其转换为单个NumPy数组,最好不要遍历整个数据集。

我看过How do I turn a Pytorch Dataloader into a numpy array to display image data with matplotlib?,但没有解决我的问题。

所以我的问题是,利用torch.utils.data.DataLoader,我将如何将数据集(训练/测试)转换为两个NumPy数组,以便所有示例都存在?

注意:我现在将批处理大小保留为默认值1;我可以将其设置为用于火车的60,000,用于测试的10,000,但是我宁愿不要使用这种魔术数字。

谢谢。

2 个答案:

答案 0 :(得分:4)

此任务无需使用from torchvision import datasets, transforms train_set = datasets.MNIST('./data', train=True, download=True) test_set = datasets.MNIST('./data', train=False, download=True) train_set_array = train_set.data.numpy() test_set_array = test_set.data.numpy()

{{1}}

请注意,在这种情况下,目标被排除在外。

答案 1 :(得分:1)

如果我对您的理解正确,那么您希望获得整个MNIST图像的训练数据集(总共60000张图像,每个图像的大小为1x28x28数组,颜色通道为1)作为大小为numpy的数组(60000、1、28 ,28)?

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Transform to normalized Tensors 
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])

train_dataset = datasets.MNIST('./MNIST/', train=True, transform=transform, download=True)
# test_dataset = datasets.MNIST('./MNIST/', train=False, transform=transform, download=True)


train_loader = DataLoader(train_dataset, batch_size=len(train_dataset))
# test_loader = DataLoader(test_dataset, batch_size=len(test_dataset))

train_dataset_array = next(iter(train_loader))[0].numpy()
# test_dataset_array = next(iter(test_loader))[0].numpy()

这是结果:

>>> train_dataset_array

array([[[[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         ...,
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296]]],


       [[[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         ...,
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296]]],


       [[[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         ...,
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296]]],


       ...,


       [[[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         ...,
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296]]],


       [[[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         ...,
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296]]],


       [[[-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         ...,
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296],
         [-0.42421296, -0.42421296, -0.42421296, ..., -0.42421296,
          -0.42421296, -0.42421296]]]], dtype=float32)