如何保存torchtext数据集?

时间:2018-11-21 23:44:12

标签: python pickle pytorch torch torchtext

我正在处理文本,并使用torchtext.data.Dataset。 创建数据集需要花费大量时间。 对于仅运行程序,这仍然可以接受。但是我想调试神经网络的割炬代码。并且,如果以调试模式启动python,则创建数据集大约需要20分钟(!!)。那只是为了获得一个工作环境,在这里我可以逐步调试神经网络代码。

我想保存数据集,例如用pickle保存。此示例代码摘自here,但我删除了此示例不需要的所有内容:

from torchtext import data
from fastai.nlp import *

PATH = 'data/aclImdb/'

TRN_PATH = 'train/all/'
VAL_PATH = 'test/all/'
TRN = f'{PATH}{TRN_PATH}'
VAL = f'{PATH}{VAL_PATH}'

TEXT = data.Field(lower=True, tokenize="spacy")

bs = 64;
bptt = 70

FILES = dict(train=TRN_PATH, validation=VAL_PATH, test=VAL_PATH)
md = LanguageModelData.from_text_files(PATH, TEXT, **FILES, bs=bs, bptt=bptt, min_freq=10)

with open("md.pkl", "wb") as file:
    pickle.dump(md, file)

要运行代码,您需要aclImdb数据集,可以从here下载。将其解压缩到此代码段旁边的data/文件夹中。代码在使用pickle的最后一行产生错误:

Traceback (most recent call last):
  File "/home/lhk/programming/fastai_sandbox/lesson4-imdb2.py", line 27, in <module>
    pickle.dump(md, file)
TypeError: 'generator' object is not callable

fastai的样品经常使用dill而不是泡菜。但这对我也不起作用。

4 个答案:

答案 0 :(得分:1)

我为自己想出了以下功能:

import dill
from pathlib import Path

import torch
from torchtext.data import Dataset

def save_dataset(dataset, path):
    if not isinstance(path, Path):
        path = Path(path)
    path.mkdir(parents=True, exist_ok=True)
    torch.save(dataset.examples, path/"examples.pkl", pickle_module=dill)
    torch.save(dataset.fields, path/"fields.pkl", pickle_module=dill)

def load_dataset(path):
    if not isinstance(path, Path):
        path = Path(path)
    examples = torch.load(path/"examples.pkl", pickle_module=dill)
    fields = torch.load(path/"fields.pkl", pickle_module=dill)
    return Dataset(examples, fields)

并非实际对象可能有所不同,例如,如果保存TabularDataset,则load_dataset返回类Dataset的实例。这不太可能影响数据管道,但可能需要额外的测试努力。 对于自定义令牌生成器,它也应该可序列化(例如,没有lambda函数等)。

答案 1 :(得分:0)

您可以使用莳萝代替泡菜。这个对我有用。 您可以保存一个torchtext字段,例如

TEXT = data.Field(sequential=True, tokenize=tokenizer, lower=True,fix_length=200,batch_first=True)
with open("model/TEXT.Field","wb")as f:
     dill.dump(TEXT,f)

并加载类似

的字段
with open("model/TEXT.Field","rb")as f:
     TEXT=dill.load(f)

官方代码支持正在开发中,您可以遵循https://github.com/pytorch/text/issues/451https://github.com/pytorch/text/issues/73

答案 2 :(得分:0)

您始终可以使用pickle来转储对象,但是请记住,模块并不会处理转储字典或字段列表的对象,因此最好尝试首先分解该列表

要将DataSet对象存储到一个pickle文件中,以便以后轻松加载

def save_to_pickle(dataSetObject,PATH):
    with open(PATH,'wb') as output:
        for i in dataSetObject:
            pickle.dump(vars(i), output, pickle.HIGHEST_PROTOCOL)

最艰难的事情还没有到来,是的,加载泡菜文件。...;)

首先,尝试查找所有字段名称和字段属性,然后进行查杀

要将泡菜文件加载到DataSetObject中

def load_pickle(PATH, FIELDNAMES, FIELD):
    dataList = []
    with open(PATH, "rb") as input_file:
        while True:
            try:
                # Taking the dictionary instance as the input Instance
                inputInstance = pickle.load(input_file)
                # plugging it into the list
                dataInstance =  [inputInstance[FIELDNAMES[0]],inputInstance[FIELDNAMES[1]]]
                # Finally creating an example objects list
                dataList.append(Example().fromlist(dataInstance,fields=FIELD))
            except EOFError:
                break

    # At last creating a data Set Object
    exampleListObject = Dataset(dataList, fields=data_fields)
    return exampleListObject 

这个骇人听闻的解决方案在我的情况下有效,希望您也能从中找到有用的方法。

欢迎提出任何建议:)。

答案 3 :(得分:0)

如果您的数据集很小,则可以使用pickle / dill方法。但是,如果您正在使用大型数据集,由于速度太慢,我不建议您这样做。

我只是简单地(以迭代方式)将示例另存为JSON字符串。这背后的原因是因为保存整个Dataset对象要花费大量时间,另外您还需要诸如莳萝之类的序列化技巧,这会使序列化变得更慢。

此外,这些序列化器会占用大量内存(其中一些甚至会创建数据集的副本),如果它们开始利用交换内存,则说明您完成了。该过程将花费很长时间,您可能会在完成之前终止它。

因此,我最终采用以下方法:

  1. 遍历示例
  2. 将每个示例(即时)转换为JSON字符串
  3. 将该JSON字符串写入文本文件(每个示例一个 行)
  4. 加载时,将示例与字段一起添加到Dataset对象中
def save_examples(dataset, savepath):
    with open(savepath, 'w') as f:
        # Save num. elements (not really need it)
        f.write(json.dumps(total))  # Write examples length
        f.write("\n")

        # Save elements
        for pair in dataset.examples:
            data = [pair.src, pair.trg]
            f.write(json.dumps(data))  # Write samples
            f.write("\n")


def load_examples(filename):
    examples = []
    with open(filename, 'r') as f:
        # Read num. elements (not really need it)
        total = json.loads(f.readline())

        # Save elements
        for i in range(total):
            line = f.readline()
            example = json.loads(line)
            # example = data.Example().fromlist(example, fields)  # Create Example obj. (you can do it here or later)
            examples.append(example)

    end = time.time()
    print(end - start)
    return examples

然后,您可以通过以下方式简单地重建数据集:

# Define fields
SRC = data.Field(...)
TRG = data.Field(...)
fields = [('src', SRC), ('trg', TRG)]

# Load examples from JSON and convert them to "Example objects"
examples = load_examples(filename)
examples = [data.Example().fromlist(d, fields) for d in examples]

# Build dataset
mydataset = Dataset(examples, fields)

我使用JSON代替pickle,dill,msgpack等的原因并不是任意的。

我做了一些测试,这些是结果:

Dataset size: 2x (1,960,641)

Saving times:
- Pickle/Dill*: >30-45 min (...or froze my computer)

- MessagePack (iterative): 123.44 sec
  100%|██████████| 1960641/1960641 [02:03<00:00, 15906.52it/s]

- JSON (iterative): 16.33 sec
  100%|██████████| 1960641/1960641 [00:15<00:00, 125955.90it/s]

- JSON (bulk): 46.54 sec (memory problems)

Loading times:
 - Pickle/Dill*: -

 - MessagePack (iterative): 143.79 sec
   100%|██████████| 1960641/1960641 [02:23<00:00, 13635.20it/s]

 - JSON (iterative): 33.83 sec
   100%|██████████| 1960641/1960641 [00:33<00:00, 57956.28it/s] 

 - JSON (bulk): 27.43 sec

*其他方法类似