Loading large text files for neural machine translation with PyTorch

Date: 2018-03-17 15:37:45

Tags: python pytorch

In PyTorch, I have written a dataset loading class that loads two text files, one as source and one as target, for neural machine translation. Each file has 93,577,946 lines, and each file occupies 8 GB on disk.

The class is as follows:

import codecs

import numpy as np
import torch
from torch.utils.data import Dataset


class LoadUniModal(Dataset):
    sources = []
    targets = []
    maxlen = 0
    lengths = []

    def __init__(self, src, trg, src_vocab, trg_vocab):
        self.src_vocab = src_vocab
        self.trg_vocab = trg_vocab

        with codecs.open(src, encoding="utf-8") as f:
            for line in f:
                tokens = line.replace("\n", "").split()
                self.maxlen = max(self.maxlen, len(tokens))
                self.sources.append(tokens)
        with codecs.open(trg, encoding="utf-8") as f:
            for line in f:
                tokens = line.replace("\n", "").split()
                self.maxlen = max(self.maxlen, len(tokens))
                self.targets.append(tokens)
                self.lengths.append(len(tokens)+2)

    # Override to give PyTorch access to any item in the dataset
    def __getitem__(self, index):

        # Source sentence processing
        tokens = self.sources[index]
        ntokens = [self.src_vocab['<START>']]
        for a in range(self.maxlen):
            if a <= (len(tokens) - 1):
                if tokens[a] in self.src_vocab.keys():
                    ntokens.append(self.src_vocab[tokens[a]])
                else:
                    ntokens.append(self.src_vocab['<UNK>'])
            elif a == len(tokens):
                ntokens.append(self.src_vocab['<END>'])
            elif a > len(tokens):
                ntokens.append(self.src_vocab['<PAD>'])

        source = torch.from_numpy(np.asarray(ntokens)).long()

        # Target sentence processing
        tokens = self.targets[index]
        ntokens = [self.trg_vocab['<START>']]
        for a in range(self.maxlen):
            if a <= (len(tokens) - 1):
                if tokens[a] in self.trg_vocab.keys():
                    ntokens.append(self.trg_vocab[tokens[a]])
                else:
                    ntokens.append(self.trg_vocab['<UNK>'])
            elif a == len(tokens):
                ntokens.append(self.trg_vocab['<END>'])
            elif a > len(tokens):
                ntokens.append(self.trg_vocab['<PAD>'])

        target = torch.from_numpy(np.asarray(ntokens)).long()

        length = self.lengths[index]

        return [0], source, target, length

    def __len__(self):
        return len(self.sources)

I use the class to load the dataset as follows:

import ast
import os

from torch.utils.data import DataLoader


def load_text_train_data(train_dir, src_vocab, trg_vocab, lang_pair, batch_size):

    tpl = ast.literal_eval(lang_pair)
    slang = tpl[1]
    tlang = tpl[2]

    strain_file = os.path.join(train_dir, "train"+slang)
    ttrain_file = os.path.join(train_dir, "train"+tlang)

    data_iter = LoadUniModal(strain_file, ttrain_file, src_vocab, trg_vocab)
    data_iter = DataLoader(data_iter, batch_size=batch_size)

    return data_iter

When I try to load the data, I get a memory error.

How can I load the data without running into memory problems?

Thanks,

1 Answer:

Answer 0 (score: 1)

It should not give an error unless you load the entire data into memory at once. One suggestion I would like to give you: do not pad all sentences to the maximum length. In machine translation data, sentence lengths generally vary a lot.
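
For illustration, here is a minimal sketch (not part of the original answer) of a __getitem__ that follows this advice: it returns unpadded token-id sequences, reusing the '<START>', '<END>' and '<UNK>' keys from the question's vocabularies, and defers padding to the mini-batch level as shown in the next sketch:

    # Sketch: return variable-length id sequences; pad later, per mini-batch.
    def __getitem__(self, index):
        src_ids = [self.src_vocab['<START>']] \
            + [self.src_vocab.get(tok, self.src_vocab['<UNK>']) for tok in self.sources[index]] \
            + [self.src_vocab['<END>']]
        trg_ids = [self.trg_vocab['<START>']] \
            + [self.trg_vocab.get(tok, self.trg_vocab['<UNK>']) for tok in self.targets[index]] \
            + [self.trg_vocab['<END>']]
        return torch.LongTensor(src_ids), torch.LongTensor(trg_ids)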

Also, you can try smaller mini-batch sizes (e.g., 32 or 64) that your memory can handle. Pad only the elements of the current mini-batch, move them to CUDA tensors, and then pass them to your model. Hope it solves your problem.
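
As a concrete sketch of this suggestion (assuming __getitem__ returns (src_ids, trg_ids) pairs as in the sketch above), a custom collate_fn can pad each mini-batch only up to the longest sequence in that batch; PAD_IDX below is a stand-in for whatever index your '<PAD>' token actually maps to:

import torch
from torch.utils.data import DataLoader

PAD_IDX = 0  # assumption: replace with src_vocab['<PAD>'] / trg_vocab['<PAD>']

def pad_collate(batch):
    # Pad source and target sequences to the longest length in this batch only.
    srcs, trgs = zip(*batch)
    max_src = max(len(s) for s in srcs)
    max_trg = max(len(t) for t in trgs)
    src_batch = torch.LongTensor(len(batch), max_src).fill_(PAD_IDX)
    trg_batch = torch.LongTensor(len(batch), max_trg).fill_(PAD_IDX)
    for i, (s, t) in enumerate(zip(srcs, trgs)):
        src_batch[i, :len(s)] = s
        trg_batch[i, :len(t)] = t
    lengths = torch.LongTensor([len(t) for t in trgs])
    return src_batch, trg_batch, lengths

# Usage sketch: small batch size, padding per batch, tensors moved to the GPU
# only inside the training loop.
# loader = DataLoader(dataset, batch_size=32, collate_fn=pad_collate)
# for src_batch, trg_batch, lengths in loader:
#     src_batch, trg_batch = src_batch.cuda(), trg_batch.cuda()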