可迭代数据集在一个时期后耗尽 [来自 torchtext 的 IMBD 数据集]

时间:2021-06-19 19:24:01

标签: python iterator pytorch torchtext

我想在情感分析任务上训练 RNN,为此我使用了由 torchtext 提供的 IMDB 数据集,其中包含 50000 条电影评论,它是一个 Python 迭代器。我使用了 split=('train', 'test')

我首先使用 torchtext.vocab.Vocab 构建了一个词汇并对每个句子进行了标记化,然后进行了数字化。

为了将序列填充到相同的长度,我使用了 torch.nn.utils.rnn.pad_sequence,并且还使用了 collate_fnbatch_sampler。然后我使用 torch.utils.data.DataLoader.

加载了数据

RNN 网络的实现很好,但数据加载器在一个 epoch 后就耗尽了,如下图所示。

我是否遵循了正确的方法来加载这个可迭代数据集?以及为什么数据加载器在一个 epoch 后就耗尽了,我该如何解决这个问题。

如果您想查看我的实现,请参阅共享的 colab 笔记本。

附注。我正在关注来自 github 的 Torchtext 的官方 changelog

您可以找到我的实现 here

Dataloader exhausted after a single epoch

在附加的图像中,您可以看到数据加载器在一个时期后耗尽。

1 个答案:

答案 0 :(得分:0)

问题是你的数据加载器是一个生成器,它在完全迭代后耗尽。一种解决方案是在每个 epoch 中初始化 dataloader。其次,是不要使用批量采样器。整理功能应该做你想做的事。

def collate_batch(batch):
    batch.sort(key=lambda x: len(x[0]), reverse=True)
    label_list, text_list, text_lengths = [], [], []
    
    for _text, _label in batch:
        label_list.append(_label)
        processed_text = torch.tensor(_text)
        text_list.append(processed_text)
        text_lengths.append(len(processed_text))

    return torch.tensor(label_list, dtype=torch.float32),
           pad_sequence(text_list, padding_value=3.0), 
           torch.tensor(text_lengths, dtype=torch.int64)