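I'm testing the MNIST dataset in PyTorch, and after applying a transformation to the X data, the DataLoader seems to put all the values out of their original order, which could mess up the training step.
My transformation divides all values by 255. Note that the transformation itself does not change the positions, as the first scatter plot shows. However, after I pass the data to the DataLoader and get it back, the values are scrambled. If I don't apply any transformation, everything is fine (not shown). The distribution of the values is the same for before, after1 (divided by 255, before the DataLoader) and after2 (divided by 255, after the DataLoader) (also not shown); only the order seems to be affected.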
import torch
from torchvision import datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
transform = transforms.ToTensor()
train = datasets.MNIST(root = '.', train = True, download = True, transform = transform)
test = datasets.MNIST(root = '.', train = False, download = True, transform = transform)
before = train.data[0]
train.data = train.data.float()/255
after1 = train.data[0]
train_loader = torch.utils.data.DataLoader(train, batch_size = 128)
test_loader = torch.utils.data.DataLoader(test, batch_size = 128)
fig, ax = plt.subplots(1, 2)
ax[0].scatter(range(len(before.view(-1))), before.view(-1))
ax[0].set_title('Before')
ax[1].scatter(range(len(after1.view(-1))), after1.view(-1))
ax[1].set_title('After1')
after2 = next(iter(train_loader))[0][0]
fig, ax = plt.subplots(1, 2)
ax[0].scatter(range(len(before.view(-1))), before.view(-1))
ax[0].set_title('Before')
ax[1].scatter(range(len(after2.view(-1))), after2.view(-1))
ax[1].set_title('After2')
fig, ax = plt.subplots(1, 3)
ax[0].imshow(before, cmap = 'gray')
ax[1].imshow(after1, cmap = 'gray')
ax[2].imshow(after2.view(28, 28), cmap = 'gray')
I know this may not be the best way to handle this data (transforms.Normalize should solve it), but I would really like to understand what is happening.
Thanks!
Answer 0 (score: 1)
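So... I posted this same question on PyTorch's GitHub page, and they answered as follows:
"It has nothing to do with the data loader. You messed with an attribute of that particular dataset object, but the object's actual __getitem__ function does more than that: https://github.com/pytorch/vision/blob/6de158c473b83cf43344a0651d7c01128c7850e6/torchvision/datasets/mnist.py#L92 In particular, this line (mode='L') assumes uint8 input. Since you replaced the data with float, the result is wrong."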
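To make the quoted explanation concrete, here is a simplified simulation of what happens when float32 data is read as 8-bit grayscale. It is not torchvision or Pillow code: it uses only Python's array module and assumes a 4-byte float32. Each float occupies 4 bytes, but an 8-bit ('L') reader takes one byte per pixel, so it consumes only the first N raw bytes of the buffer. The helper name reinterpret_as_8bit is hypothetical.

```python
from array import array


def reinterpret_as_8bit(scaled_pixels):
    """Simplified model of the bug (hypothetical helper): the pixels are
    stored as float32 (4 bytes each), but an 8-bit grayscale ('L') reader
    consumes one byte per pixel, so it only sees the first
    len(scaled_pixels) raw bytes of the float buffer."""
    raw = array('f', scaled_pixels).tobytes()  # float32 bytes, native byte order
    return list(raw[:len(scaled_pixels)])      # the bytes an 8-bit reader would use


original = [255, 255, 255, 255]        # four white pixels (uint8 values)
scaled = [p / 255 for p in original]   # 1.0 each, like train.data.float()/255
print(reinterpret_as_8bit(scaled))     # little-endian machines: [0, 0, 128, 63]
```

Output "pixel" k is byte k of the buffer, i.e. one byte of float number k // 4. The image that ToTensor receives therefore no longer matches the original pixel layout.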
So I guess the preferred approach is to apply the transformation when the dataset is prepared at the beginning of the code.
Answer 1 (score: 0)
I had not tested your code at first; here is a rewritten version of the original:
import torch
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, TensorDataset
import matplotlib.pyplot as plt
transform = transforms.ToTensor()
train = datasets.MNIST(root = '.', train = True, download = True, transform = transform)
test = datasets.MNIST(root = '.', train = False, download = True, transform = transform)
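# Scale a copy of the raw uint8 tensor and wrap it in a TensorDataset, so that
# MNIST.__getitem__ (and its PIL conversion) is never called on float data
images = train.data.float()/255
labels = train.targets
train_ds = TensorDataset(images, labels)
train_loader = DataLoader(train_ds, batch_size=128)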
# img, target = next(iter(train_loader))
before = train.data[0]
train.data = train.data.float()/255
after1 = train.data[0]
# train_loader = torch.utils.data.DataLoader(train, batch_size = 128)
test_loader = torch.utils.data.DataLoader(test, batch_size = 128)
fig, ax = plt.subplots(1, 2)
ax[0].scatter(range(len(before.view(-1))), before.view(-1))
ax[0].set_title('Before')
ax[1].scatter(range(len(after1.view(-1))), after1.view(-1))
ax[1].set_title('After1')
after2 = next(iter(train_loader))[0][0]
fig, ax = plt.subplots(1, 2)
ax[0].scatter(range(len(before.view(-1))), before.view(-1))
ax[0].set_title('Before')
ax[1].scatter(range(len(after2.view(-1))), after2.view(-1))
ax[1].set_title('After2')
fig, ax = plt.subplots(1, 3)
ax[0].imshow(before, cmap = 'gray')
ax[1].imshow(after1, cmap = 'gray')
ax[2].imshow(after2.view(28, 28), cmap = 'gray')
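To support the point that the DataLoader itself does not reorder values, here is a small check on synthetic data. The random uint8 tensor stands in for MNIST; this is an assumption made so that nothing has to be downloaded. With a TensorDataset, the first batch should be exactly the first scaled images, pixel for pixel.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for MNIST (assumption: random uint8 "images", no download).
torch.manual_seed(0)
raw = torch.randint(0, 256, (10, 28, 28), dtype=torch.uint8)
labels = torch.randint(0, 10, (10,))

images = raw.float() / 255  # same scaling as in the rewritten code
loader = DataLoader(TensorDataset(images, labels), batch_size=4)  # shuffle defaults to False

first_batch_images, first_batch_labels = next(iter(loader))
same_order = torch.equal(first_batch_images[0], raw[0].float() / 255)
print(same_order)  # True
```

The DataLoader only indexes and stacks samples. The scrambling in the question comes from MNIST.__getitem__ interpreting float data as 8-bit pixels.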