可迭代数据集在单个 epoch 后耗尽

时间:2021-06-19 14:45:23

标签: python nlp pytorch torchtext

我想在情感分析任务上训练 RNN,为此我使用了由 torchtext 提供的 IMDB 数据集,其中包含 50000 条电影评论,它是一个 Python 迭代器。我使用了 split=('train', 'test')

我首先使用 torchtext.vocab.Vocab 构建了一个词汇并对每个句子进行了标记化,然后进行了数字化。

为了将序列填充到相同的长度,我使用了 torch.nn.utils.rnn.pad_sequence,并且还使用了 collate_fnbatch_sampler。然后我使用 torch.utils.data.DataLoader.

加载了数据

RNN 网络的实现很好,但数据加载器在一个 epoch 后就耗尽了,如下图所示。

我是否遵循了正确的方法来加载这个可迭代数据集?以及为什么数据加载器在一个 epoch 后就耗尽了,我该如何解决这个问题。

如果您想查看我的实现,请参阅共享的 colab 笔记本。

附注。我正在关注来自 github 的 Torchtext 的官方 changelog

您可以找到我的实现 here

Dataloader exhausted after a single epoch

1 个答案:

答案 0 :(得分:0)

解决方案是使用 torchtext.data.functional.to_map_style_dataset(iter_data) (official doc) 将您的可迭代样式数据集转换为地图样式数据集。

像这样:

from torchtext.data.functional import to_map_style_dataset
train_iter = IMDB(split='train')
train_dataset = to_map_style_dataset(train_iter)  #Map-style dataset

然后制作一个数据加载器。

from torch.utils.data import DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=64, collate_fn=collate_fn)

为什么会这样?

我用上面例子的命名约定来解释。

传递给 train_iterDataloader 是一个 Iterable 风格的数据集,这意味着它没有实现 __getitem__。它只有 __iter____next__ dunders - 这使它成为可迭代的。

因此,如果我将一个可迭代对象传递给 Dataloader,数据加载器会在 StopIteration 异常发生后停止 - 这将被可迭代样式数据集的 __next__ dunder 抛出({{ 1}} 在这种情况下)当数据集(可迭代对象)耗尽时。

所以我们使用了 train_iter 函数将 Iterable-style 转换为 map-style 数据集。它通过实现 to_map_style_dataset dunder 来实现,因此 __getitem__ 默认使用索引从数据集中获取项目。

做同样事情的另一种可能方式也是

如果我使用可迭代样式的数据集 - 我需要在每个时期创建 Dataloader 对象。因此,在每个 epoch 之后,新的 dataloader 对象将在 for 循环中从头开始运行。

为了更好地理解 Pytorch 中 Iterable-style 和 Map-style 数据集的差异和用例,请参阅此 https://yizhepku.github.io/2020/12/26/dataloader.html