我想在情感分析任务上训练 RNN,为此我使用了由 torchtext 提供的 IMDB 数据集,其中包含 50000 条电影评论,它是一个 Python 迭代器。我使用了 split=('train', 'test')
。
我首先使用 torchtext.vocab.Vocab
构建了一个词汇并对每个句子进行了标记化,然后进行了数字化。
为了将序列填充到相同的长度,我使用了 torch.nn.utils.rnn.pad_sequence
,并且还使用了 collate_fn
和 batch_sampler
。然后我使用 torch.utils.data.DataLoader
.
RNN 网络的实现很好,但数据加载器在一个 epoch 后就耗尽了,如下图所示。
我是否遵循了正确的方法来加载这个可迭代数据集?以及为什么数据加载器在一个 epoch 后就耗尽了,我该如何解决这个问题。
如果您想查看我的实现,请参阅共享的 colab 笔记本。
附注。我正在关注来自 github 的 Torchtext 的官方 changelog
您可以找到我的实现 here
在附加的图像中,您可以看到数据加载器在一个时期后耗尽。
答案 0 :(得分:0)
问题是你的数据加载器是一个生成器,它在完全迭代后耗尽。一种解决方案是在每个 epoch 中初始化 dataloader。其次,是不要使用批量采样器。整理功能应该做你想做的事。
def collate_batch(batch):
batch.sort(key=lambda x: len(x[0]), reverse=True)
label_list, text_list, text_lengths = [], [], []
for _text, _label in batch:
label_list.append(_label)
processed_text = torch.tensor(_text)
text_list.append(processed_text)
text_lengths.append(len(processed_text))
return torch.tensor(label_list, dtype=torch.float32),
pad_sequence(text_list, padding_value=3.0),
torch.tensor(text_lengths, dtype=torch.int64)