我使用iris-dataset训练一个带有pytorch的简单网络。
trainset = iris.Iris(train=True)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=150,
shuffle=True, num_workers=2)
dataiter = iter(trainloader)
数据集本身只有150个数据点,并且由于批量大小为150,pytorch dataloader在整个数据集上迭代了一次。
现在我的问题是,通常有什么方法可以让pytorch的dataloader重复一遍数据集,如果它曾经迭代过一次吗?
thnaks
更新
让它跑步:)
刚刚创建了一个dataloader的子类,并实现了我自己的__next__()
答案 0 :(得分:1)
最简单的选择是使用嵌套循环:
for i in range(10):
for batch in trainloader:
do_something(batch)
另一个选择是使用itertools.cycle,也许与itertools.take结合使用。
当然,使用批量大小等于整个数据集的DataLoader有点不寻常。您也不需要在trainloader上调用iter()。
答案 1 :(得分:1)
使用itertools.cycle有一个重要的缺点,因为它不会在每次迭代后重排数据:
当iterable耗尽时,返回已保存副本中的元素。
在某些情况下,这会对模型的性能产生负面影响。解决这个问题的方法是编写自己的循环生成器:
def cycle(iterable):
while True:
for x in iterable:
yield x
您将用作:
dataiter = iter(cycle(trainloader))
答案 2 :(得分:0)
如果使用tqdm,最好的解决方案是:
from tqdm import tqdm
pbar = tqdm(itertools.chain(validation_loader,
validation_loader,
validation_loader,
validation_loader)) # 4 times loop through
for batch_index, (x, y) in enumerate(pbar):
...
答案 3 :(得分:0)
补充先前的答案。为了在数据集之间进行比较,通常最好将步骤总数而不是历时总数用作超参数。那是因为迭代次数不应该依赖于数据集的大小,而应该依赖于它的复杂性。
我正在使用以下代码进行培训。这样可以确保数据加载器在每次重新启动时都重新随机整理数据。
# main training loop
generator = iter(trainloader)
for i in range(max_steps):
try:
# Samples the batch
x, y = next(generator)
except StopIteration:
# restart the generator if the previous generator is exhausted.
generator = iter(trainloader)
x, y = next(generator)
我同意这不是最优雅的解决方案,但这使我不必依赖时代来训练。