Editing the architecture of a pre-trained BERT model

Date: 2021-06-15 13:48:22

Tags: tensorflow keras nlp bert-language-model transfer-learning

I found a BERT model in Google's GitHub repository, downloaded it together with its JSON config file, and loaded the model as follows.

import json
import os

import tensorflow as tf
from official.nlp import bert
import official.nlp.bert.bert_models
import official.nlp.bert.configs

# Note: os.path.join discards gs_folder_bert here because the second
# argument is an absolute path, so this resolves to the Drive path.
bert_config_file = os.path.join(gs_folder_bert, "/content/drive/My Drive/Colab Notebooks/bert_config.json")

# Read the JSON config and turn it into a BertConfig object.
config_dict = json.loads(tf.io.gfile.GFile(bert_config_file).read())
bert_config = bert.configs.BertConfig.from_dict(config_dict)

# Build the classifier and get the underlying encoder network.
bert_classifier, bert_encoder = bert.bert_models.classifier_model(bert_config, num_labels=2)
bert_encoder.summary()  # summary() already prints; no need to wrap it in print()

Model: "bert_encoder_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_word_ids (InputLayer)     [(None, None)]       0                                            
__________________________________________________________________________________________________
word_embeddings (OnDeviceEmbedd (None, None, 128)    3906816     input_word_ids[0][0]             
__________________________________________________________________________________________________
input_type_ids (InputLayer)     [(None, None)]       0                                            
__________________________________________________________________________________________________
position_embedding (PositionEmb (None, None, 128)    65536       word_embeddings[0][0]            
__________________________________________________________________________________________________
type_embeddings (OnDeviceEmbedd (None, None, 128)    256         input_type_ids[0][0]             
__________________________________________________________________________________________________
add (Add)                       (None, None, 128)    0           word_embeddings[0][0]            
                                                                 position_embedding[0][0]         
                                                                 type_embeddings[0][0]            
__________________________________________________________________________________________________
embeddings/layer_norm (LayerNor (None, None, 128)    256         add[0][0]                        
__________________________________________________________________________________________________
dropout (Dropout)               (None, None, 128)    0           embeddings/layer_norm[0][0]      
__________________________________________________________________________________________________
input_mask (InputLayer)         [(None, None)]       0                                            
__________________________________________________________________________________________________
self_attention_mask (SelfAttent (None, None, None)   0           dropout[0][0]                    
                                                                 input_mask[0][0]                 
__________________________________________________________________________________________________
transformer/layer_0 (Transforme (None, None, 128)    198272      dropout[0][0]                    
                                                                 self_attention_mask[0][0]        
__________________________________________________________________________________________________
transformer/layer_1 (Transforme (None, None, 128)    198272      transformer/layer_0[0][0]        
                                                                 self_attention_mask[0][0]        
__________________________________________________________________________________________________
tf.__operators__.getitem (Slici (None, 128)          0           transformer/layer_1[0][0]        
__________________________________________________________________________________________________
pooler_transform (Dense)        (None, 128)          16512       tf.__operators__.getitem[0][0]   
==================================================================================================
Total params: 4,385,920
Trainable params: 4,385,920
Non-trainable params: 0
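Each row of this summary is a named Keras layer, so the trained sub-layers can be pulled out of the loaded encoder individually. A quick check (assuming the bert_encoder object built above):

# Fetch sub-layers by the names shown in the summary.
transformer_0 = bert_encoder.get_layer("transformer/layer_0")
mask_layer = bert_encoder.get_layer("self_attention_mask")

print(transformer_0.count_params())  # 198272, matching the row above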

I want a model into which I can feed my own custom embedding vectors directly at transformer/layer_0, which is the 11th layer of the model, so I don't need the first ten preprocessing layers. Finally, I want to change the output activation so that I can do regression.

How can I edit the model's architecture, or create a new model that contains just these transformer layers? Thanks in advance.
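One way to approach this (a minimal sketch, not from the original post): since the encoder is a Keras functional model, its trained sub-layers can be rewired into a new model. The sketch assumes the bert_config and bert_encoder objects from the code above, and that self_attention_mask and the transformer/layer_* blocks accept [tensor, mask] input lists, as in the version of the official models library that produced this summary; the layer names come from the summary, while custom_embeddings and regression_head are hypothetical names.

import tensorflow as tf

hidden_size = bert_config.hidden_size  # 128 for this config

# New inputs: pre-computed embeddings replace the token-id pipeline;
# the mask input is still needed to build the attention mask.
custom_embeddings = tf.keras.Input(shape=(None, hidden_size), dtype=tf.float32, name="custom_embeddings")
input_mask = tf.keras.Input(shape=(None,), dtype=tf.int32, name="input_mask")

# Rewire the trained sub-layers, fetched from the encoder by name.
attention_mask = bert_encoder.get_layer("self_attention_mask")([custom_embeddings, input_mask])
x = bert_encoder.get_layer("transformer/layer_0")([custom_embeddings, attention_mask])
x = bert_encoder.get_layer("transformer/layer_1")([x, attention_mask])

# Pool the first ([CLS]) position as the original encoder does, then
# replace the classifier with a linear head for regression.
pooled = bert_encoder.get_layer("pooler_transform")(x[:, 0, :])
regression_output = tf.keras.layers.Dense(1, activation=None, name="regression_head")(pooled)

regression_model = tf.keras.Model(inputs=[custom_embeddings, input_mask], outputs=regression_output)
regression_model.compile(optimizer="adam", loss="mse")
regression_model.summary()

Because the layers are shared with bert_encoder rather than copied, their pre-trained weights carry over, and a linear (no-activation) output with an MSE loss is the usual setup for regression.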
