If I am using the TensorFlow version of Hugging Face Transformers, how do I freeze the weights of the pretrained encoder so that only the weights of the head layers are optimized?
For the PyTorch implementation, this is done with:
for param in model.base_model.parameters():
    param.requires_grad = False
I would like to do the same for the TensorFlow implementation.
Answer 0 (score: 0)
Found a way: freeze the base model before compiling it.
model = TFBertForSequenceClassification.from_pretrained("bert-base-uncased")
model.layers[0].trainable = False
model.compile(...)
Answer 1 (score: 0)
Alternatively:
model.bert.trainable = False
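To see the effect of freezing a sub-layer before compiling without downloading a pretrained checkpoint, here is a minimal tf.keras sketch; the layer names (`base_encoder`, `classifier`) are invented stand-ins for the BERT body and the task head, not Transformers API names:

```python
import tensorflow as tf

# Stand-ins: `base_encoder` plays the role of the pretrained body,
# `classifier` the role of the task head.
base = tf.keras.layers.Dense(8, name="base_encoder")
head = tf.keras.layers.Dense(2, name="classifier")

inputs = tf.keras.Input(shape=(4,))
outputs = head(base(inputs))
model = tf.keras.Model(inputs, outputs)

# Freeze the base before compile, as in the answers above.
base.trainable = False
model.compile(optimizer="adam", loss="mse")

# Only the head's weights remain in trainable_weights.
trainable = sum(int(tf.size(w)) for w in model.trainable_weights)
frozen = sum(int(tf.size(w)) for w in model.non_trainable_weights)
print(trainable, frozen)  # 18 40
```

The same counts show up in `model.summary()` under "Trainable params" / "Non-trainable params", which is how the answer below verifies the freeze on DistilBERT.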
Answer 2 (score: 0)
After digging through this thread, I believe the following code does no harm in TF2, even if it may be redundant in some cases.
model = TFBertModel.from_pretrained('bert-base-uncased')
for layer in model.layers:
    layer.trainable = False
    for w in layer.weights:
        w._trainable = False
Answer 3 (score: 0)
model = TFDistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
# Iterate over model.layers (the Keras model itself is not iterable):
for _layer in model.layers:
    if _layer.name == 'distilbert':
        print(f"Freezing model layer {_layer.name}")
        _layer.trainable = False
    print(_layer.name)
    print(_layer.trainable)
---
Freezing model layer distilbert
distilbert
False <----------------
pre_classifier
True
classifier
True
dropout_99
True
Model: "tf_distil_bert_for_sequence_classification_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
distilbert (TFDistilBertMain multiple                  66362880
_________________________________________________________________
pre_classifier (Dense)       multiple                  590592
_________________________________________________________________
classifier (Dense)           multiple                  1538
_________________________________________________________________
dropout_99 (Dropout)         multiple                  0
=================================================================
Total params: 66,955,010
Trainable params: 592,130
Non-trainable params: 66,362,880 <-----
Without freezing:
Model: "tf_distil_bert_for_sequence_classification_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
distilbert (TFDistilBertMain multiple                  66362880
_________________________________________________________________
pre_classifier (Dense)       multiple                  590592
_________________________________________________________________
classifier (Dense)           multiple                  1538
_________________________________________________________________
dropout_59 (Dropout)         multiple                  0
=================================================================
Total params: 66,955,010
Trainable params: 66,955,010
Non-trainable params: 0
Change TFDistilBertForSequenceClassification to TFBertForSequenceClassification (or your model class) accordingly. To do this, first run model.summary() to verify the base layer's name; for TFDistilBertForSequenceClassification it is distilbert.