我最近开始使用tutorial通过pytorch学习深度学习。
这些代码行有问题。
参数train=True
表示它将取出训练数据。
但培训50%需要多少数据?
我们如何指定用于训练的数据量。同样,无法理解batch_size
和num_workers
,这对加载数据数据意味着什么? batch_size
参数类似于深度学习中用于训练的参数吗?
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)
答案 0 :(得分:1)
如果您之前未拆分数据,则培训师将使用整个培训文件夹。您可以通过拆分数据来指定训练量,请参阅:
from torchvision import datasets
# convert data to a normalized torch.FloatTensor
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# choose the training and test datasets
train_data = datasets.CIFAR10('data', train=True,
download=True, transform=transform)
test_data = datasets.CIFAR10('data', train=False,
download=True, transform=transform)
valid_size = 0.2
# obtain training indices that will be used for validation
num_train = len(train_data)
indices = list(range(num_train))
np.random.shuffle(indices)
split = int(np.floor(valid_size * num_train))
train_idx, valid_idx = indices[split:], indices[:split]
# define samplers for obtaining training and validation batches
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)
# prepare data loaders (combine dataset and sampler)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
sampler=train_sampler, num_workers=num_workers)
valid_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
sampler=valid_sampler, num_workers=num_workers)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size,
num_workers=num_workers)```
批处理大小是您通过迭代(时期)捕获的文件数。例如,如果training_size为1000,而batch_size为10,则每个纪元将包含100次迭代。
工人数量用于预处理批次数据。更多的工作人员将消耗更多的内存使用量,并且工作人员将有助于加快输入和输出过程。 num_workers = 0表示将在需要时进行数据加载, num_workers> 0表示将使用您定义的工作者数对数据进行预处理。
答案 1 :(得分:0)
batch_size
是所需的批次大小(您提供的数据集中的数据组),num_workers
是处理这些批次的工作人员的数量,基本上是多处理工作人员。
但是50%的训练需要花费多少数据?
DataLoader无法为您提供任何方法来控制您希望提取的样本数量。您将必须使用对迭代器进行切片的典型方法。
最简单的操作(没有任何库)是在达到所需的样本数量后停止。
nsamples = 10000
for i, image, label in enumerate(train_loader):
if i > nsamples:
break
# Your training code here.
或者,您可以使用itertools.islice来获取前10k个样本。像这样。
for image, label in itertools.islice(train_loader, stop=10000):
# your training code here.
您可以参考此answer