Saving and loading a fine-tuned BERT classification model with TensorFlow 2.0

Date: 2020-11-09 19:24:45

Tags: nlp tensorflow2.0 bert-language-model

I am trying to save a fine-tuned binary classification model based on the pre-trained BERT module 'uncased_L-12_H-768_A-12'. I am using TF2.

This code builds the model structure:

bert_classifier, bert_encoder = bert.bert_models.classifier_model(bert_config, num_labels=2)

Then:

# restore the pre-trained encoder weights from the checkpoint file
checkpoint = tf.train.Checkpoint(model=bert_encoder)
checkpoint.restore(
    os.path.join(gs_folder_bert, 'bert_model.ckpt')).assert_consumed()

Then I compiled and fit the model:

bert_classifier.compile(
    optimizer=optimizer,
    loss=loss,
    metrics=metrics)

bert_classifier.fit(
    Text_train, Label_train,
    validation_data=(Text_val, Label_val),
    batch_size=32,
    epochs=1)

Finally, I saved the model to the model folder, which automatically generated a file named saved_model.pb inside it.

bert_classifier.save('/content/drive/My Drive/model')

I also tried this:

tf.saved_model.save(bert_classifier, export_dir='/content/drive/My Drive/omg')

Now I try to load the model and apply it to the test data:

from tensorflow import keras

ttt = keras.models.load_model('/content/drive/My Drive/model')

I get:

KeyError                                  Traceback (most recent call last)
<ipython-input-77-93f80aa585da> in <module>()
----> 1 tf.keras.models.load_model(filepath='/content/drive/My Drive/omg', custom_objects={'Transformera':bert_classifier})

9 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/load.py in _revive_graph_network(self, metadata, node_id)
    392     else:
    393       model = models_lib.Functional(
--> 394           inputs=[], outputs=[], name=config['name'])
    395 
    396     # Record this model and its layers. This will later be used to reconstruct

KeyError: 'name'

This error message does not help me figure out what to do... please advise.

I also tried saving the model in h5 format, but when loading it:

ttt = keras.models.load_model('/content/drive/My Drive/model.h5')

I got this error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-36-12f76139ec24> in <module>()
----> 1 ttt = keras.models.load_model('/content/drive/My Drive/model.h5')

5 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in class_and_config_for_serialized_keras_object(config, module_objects, custom_objects, printable_module_name)
    294   cls = get_registered_object(class_name, custom_objects, module_objects)
    295   if cls is None:
--> 296     raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
    297 
    298   cls_config = config['config']

ValueError: Unknown layer: BertClassifier

1 Answer:

Answer 0 (score: 1)

It seems you have the right answer in the question itself: '/content/drive/My Drive/model' will fail because of the space character.

You can try escaping the space: '/content/drive/My\ Drive/model'
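Before relying on the escaped form, it may be worth checking how Python itself treats the two spellings. Backslash-escaping a space is a shell convention, and Python's file APIs take the characters in a path string literally. The sketch below uses a temporary directory and a made-up marker file (both are assumptions for illustration, not the Drive path from the question) to compare the plain path with the escaped one.

```python
import os
import tempfile

# Hypothetical stand-in for '/content/drive/My Drive/model'.
base = tempfile.mkdtemp()
spaced_dir = os.path.join(base, "My Drive", "model")
os.makedirs(spaced_dir)

# Write a small marker file inside the directory whose path contains a space.
marker = os.path.join(spaced_dir, "marker.txt")
with open(marker, "w") as f:
    f.write("ok")

# Shell-style escaping: put a literal backslash in front of every space.
escaped = spaced_dir.replace(" ", "\\ ")

print("plain path exists:  ", os.path.exists(marker))
print("escaped path exists:", os.path.exists(escaped))
```

If the plain path works here, the space is probably not what breaks load_model, and the weights-only approach is the more promising route.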

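Another option, from when I ran into exactly the same problem while saving and loading: what helped was saving only the weights of the pretrained model instead of the whole model: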

Have a look here: https://keras.io/api/models/model_saving_apis/, especially the methods save_weights() and load_weights().
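A minimal sketch of the weights-only round trip, under some assumptions: build_model() is a hypothetical stand-in for whatever code constructs bert_classifier (a tiny dense network is used here so the example runs without the BERT checkpoint), and the file name ending in .weights.h5 is my choice so that the same call should work with both older tf.keras and newer Keras releases. The idea is to rebuild the architecture in code and then load only the weights, so Keras never has to deserialize the custom BertClassifier class.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_model():
    # Hypothetical stand-in for the code that builds bert_classifier.
    inputs = tf.keras.Input(shape=(8,))
    hidden = tf.keras.layers.Dense(16, activation="relu", name="hidden")(inputs)
    outputs = tf.keras.layers.Dense(2, name="logits")(hidden)
    return tf.keras.Model(inputs, outputs)


# "Trained" model: in the real case this is the model after fit().
trained = build_model()

# Save only the weights, not the architecture.
weights_path = os.path.join(tempfile.mkdtemp(), "classifier.weights.h5")
trained.save_weights(weights_path)

# Later (or in another process): rebuild the same architecture, then load.
restored = build_model()
restored.load_weights(weights_path)

# Both models should now produce the same outputs for the same input.
x = np.random.rand(4, 8).astype("float32")
same = np.allclose(trained.predict(x, verbose=0),
                   restored.predict(x, verbose=0))
print("outputs match:", same)
```

For the model in the question, the rebuild step would presumably be the same classifier_model(bert_config, num_labels=2) call used before training, followed by load_weights() on the result.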