RuntimeError:为BertModel

时间:2019-10-18 05:34:40

标签: machine-learning deep-learning pytorch transfer-learning

我使用拥抱的脸部变形器库微调BERT模型,并在云中的GPU中对其进行训练。然后,如下所示保存模型和令牌生成器:

model.save_pretrained('/saved_model/')
torch.save(best_model.state_dict(), '/saved_model/model')
tokenizer.save_pretrained('/saved_model/')

我在计算机上下载了saved_model目录。然后,将如下所示的模型/令牌加载到计算机中

import torch
from transformers import *
tokenizer = BertTokenizer.from_pretrained('./saved_model/')
config = BertConfig('./saved_model/config.json')
model = BertModel(config)
model.load_state_dict(torch.load('./saved_model/pytorch_model.bin', map_location=torch.device('cpu')))
model.eval()

但是对于model.load_state_dict行,它抛出以下错误

RuntimeError: Error(s) in loading state_dict for BertModel:
    Missing key(s) in state_dict:

它列出了一堆显然在state_dict中缺少的键。

我是pytorch的新手,不确定发生了什么。很可能我没有以正确的方式保存模型。

请提出建议。

1 个答案:

答案 0 :(得分:1)

您可能知道,PyTorch模块的state_dictOrderedDict。当您尝试从state_dict加载模块的权重时,它抱怨缺少键,这意味着state_dict不包含那些键。在这种情况下,我建议采取以下措施。

  1. 检查state_dict中存在哪些键。听起来只保存一部分键是不可能的。
  2. 此外,请确保已加载正确的配置。否则,如果您训练有素的BertModel和要为其加载权重的新BertModel不同,那么您将收到此错误。
  3. 最后,如果您的代码通过了上述两种情况,然后保存模型,请确保将所有图层的参数保存在文件中。语句torch.save(best_model.state_dict(), '/saved_model/model')在我看来还不错,但请确保best_model.state_dict()包含所有预期的密钥。