pytorch数据加载器多次迭代

时间:2017-12-08 12:41:13

标签: python pytorch

我使用iris-dataset训练一个带有pytorch的简单网络。

trainset = iris.Iris(train=True)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=150,
                                          shuffle=True, num_workers=2)

dataiter = iter(trainloader)

数据集本身只有150个数据点,并且由于批量大小为150,pytorch dataloader在整个数据集上迭代了一次。

现在我的问题是,通常有什么方法可以让pytorch的dataloader重复一遍数据集,如果它曾经迭代过一次吗?

thnaks

更新

让它跑步:) 刚刚创建了一个dataloader的子类,并实现了我自己的__next__()

4 个答案:

答案 0 :(得分:1)

最简单的选择是使用嵌套循环:

for i in range(10):
    for batch in trainloader:
        do_something(batch)

另一个选择是使用itertools.cycle,也许与itertools.take结合使用。

当然,使用批量大小等于整个数据集的DataLoader有点不寻常。您也不需要在trainloader上调用iter()。

答案 1 :(得分:1)

使用itertools.cycle有一个重要的缺点,因为它不会在每次迭代后重排数据:

  

当iterable耗尽时,返回已保存副本中的元素。

在某些情况下,这会对模型的性能产生负面影响。解决这个问题的方法是编写自己的循环生成器:

def cycle(iterable):
    while True:
        for x in iterable:
            yield x

您将用作:

dataiter = iter(cycle(trainloader))

答案 2 :(得分:0)

如果使用tqdm,最好的解决方案是:

from tqdm import tqdm
pbar = tqdm(itertools.chain(validation_loader,
    validation_loader,
    validation_loader,
    validation_loader)) # 4 times loop through
for batch_index, (x, y) in enumerate(pbar):
    ...

答案 3 :(得分:0)

补充先前的答案。为了在数据集之间进行比较,通常最好将步骤总数而不是历时总数用作超参数。那是因为迭代次数不应该依赖于数据集的大小,而应该依赖于它的复杂性。

我正在使用以下代码进行培训。这样可以确保数据加载器在每次重新启动时都重新随机整理数据。

# main training loop
    generator = iter(trainloader)
    for i in range(max_steps):

        try:
            # Samples the batch
            x, y = next(generator)
        except StopIteration:
            # restart the generator if the previous generator is exhausted.
            generator = iter(trainloader)
            x, y = next(generator)

我同意这不是最优雅的解决方案,但这使我不必依赖时代来训练。