为什么在Huggingface Transformers中的BERT预训练模型中需要init_weight函数?

时间:2020-05-27 09:54:16

标签: python huggingface-transformers bert-language-model

在Hugginface转换器的代码中,有许多具有功能init_weight的微调模型。 例如(here),最后有一个init_weight函数。

class BertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

据我所知,它将调用以下code

def _init_weights(self, module):
    """ Initialize the weights """
    if isinstance(module, (nn.Linear, nn.Embedding)):
        # Slightly different from the TF version which uses truncated_normal for initialization
        # cf https://github.com/pytorch/pytorch/pull/5617
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
    elif isinstance(module, BertLayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()

我的问题是如果我们要加载预先训练的模型,为什么我们需要为每个模块初始化权重?

我想我一定是误会了。

3 个答案:

答案 0 :(得分:1)

看看.from_pretrained()的代码。实际发生的是这样的:

  • 找到正确的基础模型类进行初始化
  • 使用伪随机初始化初始化该类(通过使用您提到的_init_weights函数)
  • 使用预先设置的权重查找文件
  • 在适用的情况下使用预先训练的权重覆盖我们刚刚创建的模型的权重

这可确保未经预训练的层(例如,在某些情况下为最终分类层)在_init_weights中初始化,但不会被覆盖。

答案 1 :(得分:0)

BertPreTrainedModel是一个抽象类,如果您检查的话,错误是BertPreTrainedModel类甚至没有构造函数,甚至认为它正在被调用,您可以使用PR来修饰这段代码,但是确保首先创建问题。

答案 2 :(得分:0)

因此,在使用转换器库时,我是否需要调用init_weights()

我的意思是,下面的代码足够吗?

class SentimentClassifier(nn.Module):
def __init__(self, num_classes):
    super(SentimentClassifier, self).__init__()
    self.bert_layer = BertModel.from_pretrained('bert-base-uncased')

    self.cls_layer = nn.Linear(768, num_classes)


def forward().....