如何加载部分预训练的pytorch模型?

时间:2020-04-14 15:42:17

标签: python machine-learning pytorch pre-trained-model spacy-pytorch-transformers

我正在尝试在句子分类任务上运行pytorch模型。在处理医学笔记时,我正在使用ClinicalBert(https://github.com/kexinhuang12345/clinicalBERT),并希望使用其预先训练的权重。不幸的是,当我有281个二进制标签时,ClinicalBert模型仅将文本分类为1个二进制标签。因此,我正在尝试实现此代码https://github.com/kaushaltrivedi/bert-toxic-comments-multilabel/blob/master/toxic-bert-multilabel-classification.ipynb,其中bert后的末端分类器长281。

如何在不加载分类权重的情况下从ClinicalBert模型加载预先训练的Bert权重?

天真地尝试从预先训练的ClinicalBert重量中加载重量,但出现以下错误:

size mismatch for classifier.weight: copying a param with shape torch.Size([2, 768]) from checkpoint, the shape in current model is torch.Size([281, 768]).
size mismatch for classifier.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([281]).

我目前尝试从pytorch_pretrained_bert包中替换from_pretrained函数,并弹出分类器权重和偏倚,如下所示:

def from_pretrained(cls, pretrained_model_name, state_dict=None, cache_dir=None, *inputs, **kwargs):
    ...
    if state_dict is None:
        weights_path = os.path.join(serialization_dir, WEIGHTS_NAME)
        state_dict = torch.load(weights_path, map_location='cpu')
    state_dict.pop('classifier.weight')
    state_dict.pop('classifier.bias')
    old_keys = []
    new_keys = []
    ...

并且我收到以下错误消息: 信息-模型诊断-未从预训练模型初始化BertForMultiLabelSequenceClassification的权重:['classifier.weight','classifier.bias']

最后,我想从ClinicalBert预训练权重中加载bert嵌入,并随机初始化顶级分类器权重。

1 个答案:

答案 0 :(得分:1)

在加载之前删除状态dict中的键是一个好的开始。假设您正在使用nn.Module.load_state_dict来加载预训练的权重,那么您还需要设置strict=False参数,以避免因意外或丢失键而导致的错误。这将忽略state_dict中模型中不存在的条目(意外的键),并且对您而言更重要的是,将使缺失的条目保留其默认初始化(缺少键)。为了安全起见,您可以检查该方法的返回值,以验证所涉及的砝码是丢失的钥匙的一部分,并且没有任何意外的钥匙。