我想在情感分析任务上训练 RNN,为此我使用了由 torchtext 提供的 IMDB 数据集,其中包含 50000 条电影评论,它是一个 Python 迭代器。我使用了 split=('train', 'test')
。
我首先使用 torchtext.vocab.Vocab
构建了一个词汇并对每个句子进行了标记化,然后进行了数字化。
为了将序列填充到相同的长度,我使用了 torch.nn.utils.rnn.pad_sequence
,并且还使用了 collate_fn
和 batch_sampler
。然后我使用 torch.utils.data.DataLoader
.
RNN 网络的实现很好,但数据加载器在一个 epoch 后就耗尽了,如下图所示。
我是否遵循了正确的方法来加载这个可迭代数据集?以及为什么数据加载器在一个 epoch 后就耗尽了,我该如何解决这个问题。
如果您想查看我的实现,请参阅共享的 colab 笔记本。
附注。我正在关注来自 github 的 Torchtext 的官方 changelog
您可以找到我的实现 here
答案 0 :(得分:0)
解决方案是使用 torchtext.data.functional.to_map_style_dataset(iter_data)
(official doc) 将您的可迭代样式数据集转换为地图样式数据集。
像这样:
from torchtext.data.functional import to_map_style_dataset
train_iter = IMDB(split='train')
train_dataset = to_map_style_dataset(train_iter) #Map-style dataset
然后制作一个数据加载器。
from torch.utils.data import DataLoader
train_dataloader = DataLoader(train_dataset, batch_size=64, collate_fn=collate_fn)
我用上面例子的命名约定来解释。
传递给 train_iter
的 Dataloader
是一个 Iterable 风格的数据集,这意味着它没有实现 __getitem__
。它只有 __iter__
和 __next__
dunders - 这使它成为可迭代的。
因此,如果我将一个可迭代对象传递给 Dataloader
,数据加载器会在 StopIteration
异常发生后停止 - 这将被可迭代样式数据集的 __next__
dunder 抛出({{ 1}} 在这种情况下)当数据集(可迭代对象)耗尽时。
所以我们使用了 train_iter
函数将 Iterable-style 转换为 map-style 数据集。它通过实现 to_map_style_dataset
dunder 来实现,因此 __getitem__
默认使用索引从数据集中获取项目。
如果我使用可迭代样式的数据集 - 我需要在每个时期创建 Dataloader
对象。因此,在每个 epoch 之后,新的 dataloader 对象将在 for 循环中从头开始运行。
为了更好地理解 Pytorch 中 Iterable-style 和 Map-style 数据集的差异和用例,请参阅此 https://yizhepku.github.io/2020/12/26/dataloader.html