在Pytorch中,使用Torchvision MNIST数据集时,我们可以获得如下数字:
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, TensorDataset
tsfm = transforms.Compose([transforms.Resize((16, 16)),
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))])
mnist_ds = torchvision.datasets.MNIST(root='../../../_data/mnist',train=True,
download=True, transform=tsfm)
digit_12 = mnist_ds[12]
尽管可以对大多数数据集进行切片,但我们不能对其中一个切片:
digit_12_to_14 = mnist_ds[12:15]
将返回
ValueError: Too many dimensions: 3 > 2.
这是由于getItem()中的Image.fromarray()
是否可以在不使用Dataloader的情况下使用MNIST数据集?怎么样?
PS:我想避免使用Dataloader的原因是,一次向GPU发送一批数据会减慢训练速度。我更喜欢一次将所有数据发送到GPU。为此,我需要访问整个TRANSFORMED数据集。
答案 0 :(得分:1)
Dataset
界面仅要求
所有子类均应覆盖
__len__
(提供数据集的大小)和__getitem__
(支持从0
到len(self)
的整数索引)。
显然没有提到切片-其他数据集的切片行为是一项额外功能。如果要一次获取全部数据,则可以查找implementation并仅使用在mnist.data
末尾定义的mnist.targets
和__init__
张量。>
如果要转换数据,可以使用
data = [mnist_ds[i] for i in range(len(mnist_ds))]
xs = torch.stack([d[0] for d in data], dim=0)
ys = torch.stack([d[1] for d in data], dim=0)
或一次转换所有mnist.data
张量(尽管不适用于torchvision.transform
转换)。
答案 1 :(得分:1)
I found 2 solutions so far to convert torchvision MNIST dataset to tensors. The first one is derived from Fábio Perez comment :
print("\nFirst...")
st = time()
x_all_ts = torch.tensor([mnist_ds[i][0].numpy() for i in range(0, len(mnist_ds))])
t_all_ts = mnist_ds.train_labels
print(f"{time()-st} images:{x_all_ts.size()} targets:{t_all_ts.size()} ")
print("\nSecond...")
st = time()
mnist_dl = DataLoader(dataset=mnist_ds, batch_size=len(mnist_ds))
x_all_ts2, t_all_ts2 = list(mnist_dl)[0]
print(f"{time()-st} images:{x_all_ts2.size()} targets:{t_all_ts2.size()} ")
First...
19.573785066604614 images:torch.Size([60000, 1, 16, 16]) targets:torch.Size([60000])
Second...
16.826476573944092 images:torch.Size([60000, 1, 16, 16]) targets:torch.Size([60000])
Please let me know if you find better ones.
答案 2 :(得分:0)
您可以使用 torch.utils.data.Subset()
获取基于索引的火炬切片 Dataset
例如:
import torch.utils.data as data_utils
indices = torch.arange(12,15)
mnist_12to14 = data_utils.Subset(tr, indices)