Pytorch DataLoader-处理单个批次

时间:2018-11-07 07:16:13

标签: python lstm pytorch

我正在为词性标记构建LSTM,目前正在使用this LSTM作为参考,以了解如何使用Pytorch DataLoader。但是,我对如何处理单个批次感到困惑(Pytorch官方文档对我没有任何帮助)。更具体地说,我需要填充我的句子,以使同一批次中的所有 句子具有相同的长度(不同的批次可以具有不同的长度)。

我的代码当前如下所示:

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE)

# other stuff...

for epoch in range(0, EPOCH_COUNT):
    for iter, sentence in enumerate(train_loader):
        model.zero_grad()
        model.hidden = model.init_hidden()
        sentence_in = prep_sentence(sentence, word_to_ix)
        targets = prep_tags(tags, tag_to_ix)
        # do forward pass and then backprop

由于我正在枚举train_loader并遍历每个句子,因此我不清楚如何/在什么时候可以获得单批的最大句子长度,然后运行我的add_padding()函数。

0 个答案:

没有答案