In PyTorch, I have written a dataset-loading class that loads two text files, source and target, for neural machine translation. Each file has 93577946 lines and takes up 8 GB on disk.
The class is as follows:
import codecs
import numpy as np
import torch
from torch.utils.data import Dataset

class LoadUniModal(Dataset):
    sources = []
    targets = []
    maxlen = 0
    lengths = []

    def __init__(self, src, trg, src_vocab, trg_vocab):
        self.src_vocab = src_vocab
        self.trg_vocab = trg_vocab
        with codecs.open(src, encoding="utf-8") as f:
            for line in f:
                tokens = line.replace("\n", "").split()
                self.maxlen = max(self.maxlen, len(tokens))
                self.sources.append(tokens)
        with codecs.open(trg, encoding="utf-8") as f:
            for line in f:
                tokens = line.replace("\n", "").split()
                self.maxlen = max(self.maxlen, len(tokens))
                self.targets.append(tokens)
                self.lengths.append(len(tokens) + 2)

    # Override to give PyTorch access to any item in the dataset
    def __getitem__(self, index):
        # Source sentence processing
        tokens = self.sources[index]
        ntokens = [self.src_vocab['<START>']]
        for a in range(self.maxlen):
            if a <= (len(tokens) - 1):
                if tokens[a] in self.src_vocab.keys():
                    ntokens.append(self.src_vocab[tokens[a]])
                else:
                    ntokens.append(self.src_vocab['<UNK>'])
            elif a == len(tokens):
                ntokens.append(self.src_vocab['<END>'])
            elif a > len(tokens):
                ntokens.append(self.src_vocab['<PAD>'])
        source = torch.from_numpy(np.asarray(ntokens)).long()

        # Target sentence processing
        tokens = self.targets[index]
        ntokens = [self.trg_vocab['<START>']]
        for a in range(self.maxlen):
            if a <= (len(tokens) - 1):
                if tokens[a] in self.trg_vocab.keys():
                    ntokens.append(self.trg_vocab[tokens[a]])
                else:
                    ntokens.append(self.trg_vocab['<UNK>'])
            elif a == len(tokens):
                ntokens.append(self.trg_vocab['<END>'])
            elif a > len(tokens):
                ntokens.append(self.trg_vocab['<PAD>'])
        target = torch.from_numpy(np.asarray(ntokens)).long()

        length = self.lengths[index]
        return [0], source, target, length

    def __len__(self):
        return len(self.sources)
I use this class to load the dataset as follows:
import ast
import os
from torch.utils.data import DataLoader

def load_text_train_data(train_dir, src_vocab, trg_vocab, lang_pair, batch_size):
    tpl = ast.literal_eval(lang_pair)
    slang = tpl[1]
    tlang = tpl[2]
    strain_file = os.path.join(train_dir, "train" + slang)
    ttrain_file = os.path.join(train_dir, "train" + tlang)
    data_iter = LoadUniModal(strain_file, ttrain_file, src_vocab, trg_vocab)
    data_iter = DataLoader(data_iter, batch_size=batch_size)
    return data_iter
When I try to load the data, I get a memory error.
How can I load the data without running into memory problems?
Thanks,
Answer 0 (score: 1)
It should not give an error unless you load the entire dataset into memory at once. One suggestion I would give you: do not pad every sentence to the maximum length. In machine translation data, sentence lengths generally vary a lot.
Also, try a smaller mini-batch size (e.g. 32 or 64) that your memory can handle. Pad only the elements of the current mini-batch, move them to CUDA tensors, and then pass them to your model. Hope this solves your problem.
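To make the per-mini-batch padding concrete, here is a minimal sketch using a DataLoader collate_fn. It assumes __getitem__ is changed to return unpadded, variable-length (source, target, length) tuples with source and target as LongTensors; the names make_pad_collate and pad_collate are illustrative and not part of the original code.

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def make_pad_collate(src_pad_id, trg_pad_id):
    def pad_collate(batch):
        # batch is a list of (source, target, length) tuples of varying lengths
        sources, targets, lengths = zip(*batch)
        # Pad only up to the longest sentence in this mini-batch
        src = pad_sequence(sources, batch_first=True, padding_value=src_pad_id)
        trg = pad_sequence(targets, batch_first=True, padding_value=trg_pad_id)
        return src, trg, torch.tensor(lengths)
    return pad_collate

Usage with the dataset and vocabularies from the question (a sketch, assuming src_vocab and trg_vocab contain a '<PAD>' entry):

loader = DataLoader(dataset, batch_size=32,
                    collate_fn=make_pad_collate(src_vocab['<PAD>'], trg_vocab['<PAD>']))
for src, trg, lengths in loader:
    # Move only the current mini-batch to the GPU, then run the model
    src, trg = src.cuda(), trg.cuda()

This way the padding cost depends on the longest sentence in each mini-batch rather than the longest sentence in the whole corpus, and only one small batch at a time lives on the GPU.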