我做了一个实验,但没有得到预期的结果。
第一部分,我正在使用
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
shuffle=False, num_workers=0)
在训练模型之前,我将trainloader.dataset.targets
保存到变量a
,并将trainloader.dataset.data
保存到变量b
。然后,我使用trainloader
训练模型。
培训结束后,我将trainloader.dataset.targets
保存到变量c
,并将trainloader.dataset.data
保存到变量d
。最后,我检查了a == c
和b == d
,它们都给出了True
,这是可以预期的,因为DataLoader
的shuffle参数是False
。
对于第二部分,我正在使用
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
shuffle=True, num_workers=0)
在训练模型之前,我将trainloader.dataset.targets
保存到变量e
,并将trainloader.dataset.data
保存到变量f
。然后,我使用trainloader
训练模型。训练结束后,我将trainloader.dataset.targets
保存到变量g
,并将trainloader.dataset.data
保存到变量h
。我期望e == g
以来的f == h
和False
都是shuffle=True
,但是他们又给True
。 DataLoader
类的定义中我缺少什么?
答案 0 :(得分:1)
我相信直接存储在trainloader.dataset.data或.target中的数据不会被改组,仅当将DataLoader称为生成器或迭代器时才对数据进行改组
您可以通过多次执行next(iter(trainloader))来进行检查,而无需混洗和混洗,它们应该给出不同的结果
import torch
import torchvision
transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
])
MNIST_dataset = torchvision.datasets.MNIST('~/Desktop/intern/',download = True, train = False,
transform = transform)
dataLoader = torch.utils.data.DataLoader(MNIST_dataset,
batch_size = 128,
shuffle = False,
num_workers = 10)
target = dataLoader.dataset.targets
MNIST_dataset = torchvision.datasets.MNIST('~/Desktop/intern/',download = True, train = False,
transform = transform)
dataLoader_shuffled= torch.utils.data.DataLoader(MNIST_dataset,
batch_size = 128,
shuffle = True,
num_workers = 10)
target_shuffled = dataLoader_shuffled.dataset.targets
print(target == target_shuffled)
_, target = next(iter(dataLoader));
_, target_shuffled = next(iter(dataLoader_shuffled))
print(target == target_shuffled)
这将给出:
tensor([True, True, True, ..., True, True, True])
tensor([False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, True,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, True, False, False, False, False, False,
False, True, False, False, False, False, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, True, True, False, False, False, False,
False, False, False, False, False, False, False, False, False, False,
False, False, False, False, False, True, False, False, True, False,
False, False, False, False, False, False, False, False])
但是,存储在数据和目标中的数据和标签是固定列表,并且由于您尝试直接访问它,因此它们不会被重新排列。
答案 1 :(得分:0)
在使用Dataset类加载数据时,我遇到了类似的问题。我停止使用Dataset类加载数据,而是使用下面的代码对我来说很好
X = torch.from_numpy(X)
y = torch.from_numpy(y)
train_data = torch.utils.data.TensorDataset(X, y)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=True)
其中X和y是csv文件中的numpy数组。