Pytorch-无法切片Torchvision MNIST数据集

时间:2019-01-18 10:13:34

标签: dataset slice pytorch

在Pytorch中,使用Torchvision MNIST数据集时,我们可以获得如下数字:

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, TensorDataset

tsfm = transforms.Compose([transforms.Resize((16, 16)),
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))])

mnist_ds = torchvision.datasets.MNIST(root='../../../_data/mnist',train=True,
                                download=True, transform=tsfm)


digit_12 = mnist_ds[12]

尽管可以对大多数数据集进行切片,但我们不能对其中一个切片:

digit_12_to_14 = mnist_ds[12:15]

将返回

ValueError: Too many dimensions: 3 > 2.

这是由于getItem()中的Image.fromarray()

是否可以在不使用Dataloader的情况下使用MNIST数据集?怎么样?

PS:我想避免使用Dataloader的原因是,一次向GPU发送一批数据会减慢训练速度。我更喜欢一次将所有数据发送到GPU。为此,我需要访问整个TRANSFORMED数据集。

3 个答案:

答案 0 :(得分:1)

Dataset界面仅要求

  

所有子类均应覆盖__len__(提供数据集的大小)和__getitem__(支持从0len(self)的整数索引)。

显然没有提到切片-其他数据集的切片行为是一项额外功能。如果要一次获取全部数据,则可以查找implementation并仅使用在mnist.data末尾定义的mnist.targets__init__张量。

如果要转换数据,可以使用

data = [mnist_ds[i] for i in range(len(mnist_ds))]
xs = torch.stack([d[0] for d in data], dim=0)
ys = torch.stack([d[1] for d in data], dim=0)

或一次转换所有mnist.data张量(尽管不适用于torchvision.transform转换)。

答案 1 :(得分:1)

I found 2 solutions so far to convert torchvision MNIST dataset to tensors. The first one is derived from Fábio Perez comment :

print("\nFirst...")
st = time()
x_all_ts = torch.tensor([mnist_ds[i][0].numpy() for i in range(0, len(mnist_ds))])
t_all_ts = mnist_ds.train_labels
print(f"{time()-st}   images:{x_all_ts.size()}  targets:{t_all_ts.size()} ")

print("\nSecond...")
st = time()
mnist_dl = DataLoader(dataset=mnist_ds, batch_size=len(mnist_ds))
x_all_ts2, t_all_ts2 = list(mnist_dl)[0]
print(f"{time()-st}   images:{x_all_ts2.size()}  targets:{t_all_ts2.size()} ")


First...
19.573785066604614   images:torch.Size([60000, 1, 16, 16])  targets:torch.Size([60000]) 
Second...
16.826476573944092   images:torch.Size([60000, 1, 16, 16])  targets:torch.Size([60000]) 

Please let me know if you find better ones.

答案 2 :(得分:0)

您可以使用 torch.utils.data.Subset() 获取基于索引的火炬切片 Dataset 例如:

import torch.utils.data as data_utils

indices = torch.arange(12,15)
mnist_12to14 = data_utils.Subset(tr, indices)