我有一个问题,如何从pytorch数据加载器中获取批处理迭代的总数?
以下是常见的培训代码
for i, batch in enumerate(dataloader):
然后,有什么方法可以获取“ for循环”的迭代总数吗?
在我的NLP问题中,迭代总数不同于int(n_train_samples / batch_size)...
例如,如果我仅截断训练数据10,000个样本并将批次大小设置为1024,那么在我的NLP问题中会发生363次迭代。
我想知道如何在“ for-loop”中获得总迭代次数。
谢谢。
答案 0 :(得分:7)
len(dataloader)
返回批次总数。它取决于数据集的__len__
函数,因此请确保已正确设置。
答案 1 :(得分:0)
创建数据加载器时还有一个附加参数。它称为drop_last
。
如果drop_last=True
,则长度为number_of_training_examples // batch_size
。
如果drop_last=False
可能是number_of_training_examples // batch_size +1
。
BS=128
ds_train = torchvision.datasets.CIFAR10('/data/cifar10', download=True, train=True, transform=t_train)
dl_train = DataLoader( ds_train, batch_size=BS, drop_last=True, shuffle=True)
对于预定义的数据集,您可能会看到许多示例,例如:
# number of examples
len(dl_train.dataset)
数据加载器中的正确批数始终为:
# number of batches
len(dl_train)