How to freeze a TFBertForSequenceClassification pretrained model?

Date: 2020-07-01 07:25:04

Tags: tensorflow huggingface-transformers

If I'm using the TensorFlow version of Hugging Face Transformers, how do I freeze the weights of the pretrained encoder so that only the weights of the head layers are optimized?

For the PyTorch implementation, this is done with:
# Freeze every parameter of the base (encoder) model so that
# only the head's parameters receive gradient updates.
for param in model.base_model.parameters():
    param.requires_grad = False

I'd like to do the same thing for the TensorFlow implementation.

4 answers:

Answer 0 (score: 0)

Found a way: freeze the base layer before compiling the model.

from transformers import TFBertForSequenceClassification

model = TFBertForSequenceClassification.from_pretrained("bert-base-uncased")
model.layers[0].trainable = False  # layers[0] is the TFBertMainLayer (the encoder)
model.compile(...)
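Spelled out, a minimal compile step after freezing might look like the following (the optimizer and loss choices here are illustrative assumptions, not part of the original answer):

import tensorflow as tf

# Only the classification head should remain trainable at this point.
print("trainable tensors:", len(model.trainable_weights))

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=5e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
model.summary()  # "Non-trainable params" should now cover the whole encoder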

Answer 1 (score: 0)

Alternatively:

model.bert.trainable = False
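Here model.bert is the same TFBertMainLayer as model.layers[0] above, just accessed by attribute. If you only want to freeze part of the encoder rather than all of it, something along these lines should work (a sketch assuming the transformers TF BERT layout, where model.bert.encoder.layer is the list of transformer blocks):

from transformers import TFBertForSequenceClassification

model = TFBertForSequenceClassification.from_pretrained("bert-base-uncased")

# Freeze the embeddings and the first 8 of the 12 encoder blocks;
# the remaining blocks and the classification head stay trainable.
model.bert.embeddings.trainable = False
for block in model.bert.encoder.layer[:8]:
    block.trainable = False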

Answer 2 (score: 0)

After digging through this thread, I believe the following code won't hurt anything in TF2, even if it may be redundant in certain cases.

from transformers import TFBertModel

model = TFBertModel.from_pretrained('./bert-base-uncased')
for layer in model.layers:
    layer.trainable = False
    # Also force each variable's private flag, in case the layer-level setting does not propagate
    for w in layer.weights:
        w._trainable = False
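One quick way to confirm the freeze took effect is to check the trainable-weight count afterwards (a verification snippet, not part of the original answer):

# Both checks should report nothing left to train.
print("trainable tensors:", len(model.trainable_weights))                 # expected: 0
print("all layers frozen:", all(not l.trainable for l in model.layers))   # expected: True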

Answer 3 (score: 0)

from transformers import TFDistilBertForSequenceClassification

model = TFDistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
for _layer in model.layers:
    if _layer.name == 'distilbert':
        print(f"Freezing model layer {_layer.name}")
        _layer.trainable = False
    print(_layer.name)
    print(_layer.trainable)
---
Freezing model layer distilbert
distilbert
False      <----------------
pre_classifier
True
classifier
True
dropout_99
True

Model: "tf_distil_bert_for_sequence_classification_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
distilbert (TFDistilBertMain multiple                  66362880  
_________________________________________________________________
pre_classifier (Dense)       multiple                  590592    
_________________________________________________________________
classifier (Dense)           multiple                  1538      
_________________________________________________________________
dropout_99 (Dropout)         multiple                  0         
=================================================================
Total params: 66,955,010
Trainable params: 592,130
Non-trainable params: 66,362,880   <-----

Without freezing:

Model: "tf_distil_bert_for_sequence_classification_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
distilbert (TFDistilBertMain multiple                  66362880  
_________________________________________________________________
pre_classifier (Dense)       multiple                  590592    
_________________________________________________________________
classifier (Dense)           multiple                  1538      
_________________________________________________________________
dropout_59 (Dropout)         multiple                  0         
=================================================================
Total params: 66,955,010
Trainable params: 66,955,010
Non-trainable params: 0

Change TFDistilBertForSequenceClassification to TFBertForSequenceClassification as appropriate. To do so, first run model.summary() to verify the base layer's name; for TFDistilBertForSequenceClassification it is distilbert.
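For example, adapting the same loop to BERT might look like this (a sketch: for TFBertForSequenceClassification the base layer is named bert, which you can confirm with model.summary()):

from transformers import TFBertForSequenceClassification

model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased')
model.summary()  # confirm the base layer's name first

for _layer in model.layers:
    if _layer.name == 'bert':  # base layer for TFBertForSequenceClassification
        print(f"Freezing model layer {_layer.name}")
        _layer.trainable = False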