以AllenNLP软件包保存/加载模型

时间:2019-01-17 01:47:00

标签: pytorch allennlp

我正在尝试加载AllenNLP模型权重。我找不到有关如何保存/加载整个模型的任何文档,因此只能玩重物。

from allennlp.nn import util
model_state = torch.load(filename_model, map_location=util.device_mapping(-1))
model.load_state_dict(model_state)

我稍微修改了输入语料库,因此我猜测是因为我得到语料库大小不匹配:

RuntimeError: Error(s) in loading state_dict for BasicTextFieldEmbedder:

    size mismatch for token_embedder_tokens.weight: 
    copying a param with shape torch.Size([2117, 16]) from checkpoint, 
    the shape in current model is torch.Size([2129, 16]).

似乎没有官方的方法来保存语料库词汇。周围有什么骇客吗?

1 个答案:

答案 0 :(得分:0)

AllenNLP中有一项功能可以加载或保存模型。 您是否遵循AllenNLP's tutorial中概述的步骤?下面,我粘贴了您可能感兴趣的教程片段:

# Here's how to save the model.
with open("/tmp/model.th", 'wb') as f:
    torch.save(model.state_dict(), f)

vocab.save_to_files("/tmp/vocabulary")

# And here's how to reload the model.
vocab2 = Vocabulary.from_files("/tmp/vocabulary")

model2 = LstmTagger(word_embeddings, lstm, vocab2)
with open("/tmp/model.th", 'rb') as f:
    model2.load_state_dict(torch.load(f))

如果上述方法对您不起作用,则可以检查allennlp.models.archival.archive_model助手功能。使用此功能,您应该能够将模型的训练配置以及权重和词汇表归档到model.tar.gz。 Here,您可以找到有关我所讨论的两种方法的局限性的更多信息