DataLoader弄乱转换后的数据

时间:2019-09-13 13:39:22

标签: pytorch dataloader

我正在Pytorch中测试MNIST数据集,对X数据进行转换后,似乎DataLoader会将所有值都置于原始顺序之外,可能会弄乱训练步骤。

我的变换是将所有值除以255。请注意,变换本身不会改变位置,如第一个散点图所示。但是,在将数据传递到DataLoader并将其取回之后,它们就混乱了。如果我不进行任何转换,则一切都很好(未显示)。值的分布在 before after1 (除以255 / DataLoader之前)和 after2 (除以255 / DataLoader之后)之间是相同的)(也未显示),似乎只有顺序受到了影响。

import torch
from torchvision import datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

transform = transforms.ToTensor()

train = datasets.MNIST(root = '.', train = True, download = True, transform = transform)
test = datasets.MNIST(root = '.', train = False, download = True, transform = transform)

before = train.data[0]

train.data = train.data.float()/255
after1 = train.data[0]

train_loader = torch.utils.data.DataLoader(train, batch_size = 128)
test_loader = torch.utils.data.DataLoader(test, batch_size = 128)

fig, ax = plt.subplots(1, 2)
ax[0].scatter(range(len(before.view(-1))), before.view(-1))
ax[0].set_title('Before')
ax[1].scatter(range(len(after1.view(-1))), after1.view(-1))
ax[1].set_title('After1')

after2 = next(iter(train_loader))[0][0]

fig, ax = plt.subplots(1, 2)
ax[0].scatter(range(len(before.view(-1))), before.view(-1))
ax[0].set_title('Before')
ax[1].scatter(range(len(after2.view(-1))), after2.view(-1))
ax[1].set_title('After2')

fig, ax = plt.subplots(1, 3)
ax[0].imshow(before, cmap = 'gray')
ax[1].imshow(after1, cmap = 'gray')
ax[2].imshow(after2.view(28, 28), cmap = 'gray')

我知道这可能不是处理这些数据的最佳方法(transforms.Normalize应该可以解决它),但是我真的很想了解正在发生的事情。

谢谢!

2 个答案:

答案 0 :(得分:1)

所以...我posted this same question at Pytorch's GitHub page,他们回答了以下问题:

  

它与数据加载器无关。您搞砸了一个属性   特定的数据集对象,但是实际的__getitem__   对象的功能更多:   https://github.com/pytorch/vision/blob/6de158c473b83cf43344a0651d7c01128c7850e6/torchvision/datasets/mnist.py#L92

     

特别是,此行(mode='L')假设输入uint8。自从你   用float替换它,这是错误的。

然后我想首选的方法是在代码的开头准备数据集时应用转换。

答案 1 :(得分:0)

最初,我尚未测试您编写的代码。 改写了原件:

import torch
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, TensorDataset
import matplotlib.pyplot as plt

transform = transforms.ToTensor()

train = datasets.MNIST(root = '.', train = True, download = True, transform = transform)
test = datasets.MNIST(root = '.', train = False, download = True, transform = transform)

dl = DataLoader(train)

images = dl.dataset.data.float()/255
labels = dl.dataset.targets

train_ds = TensorDataset(images, labels)
train_loader = DataLoader(train_ds, batch_size=128)
# img, target = next(iter(train_loader))

before = train.data[0]
train.data = train.data.float()/255
after1 = train.data[0]

# train_loader = torch.utils.data.DataLoader(train, batch_size = 128)
test_loader = torch.utils.data.DataLoader(test, batch_size = 128)

fig, ax = plt.subplots(1, 2)
ax[0].scatter(range(len(before.view(-1))), before.view(-1))
ax[0].set_title('Before')
ax[1].scatter(range(len(after1.view(-1))), after1.view(-1))
ax[1].set_title('After1')

after2 = next(iter(train_loader))[0][0]

fig, ax = plt.subplots(1, 2)
ax[0].scatter(range(len(before.view(-1))), before.view(-1))
ax[0].set_title('Before')
ax[1].scatter(range(len(after2.view(-1))), after2.view(-1))
ax[1].set_title('After2')

fig, ax = plt.subplots(1, 3)
ax[0].imshow(before, cmap = 'gray')
ax[1].imshow(after1, cmap = 'gray')
ax[2].imshow(after2.view(28, 28), cmap = 'gray')