我正在尝试基于预训练的Bert模块'uncased_L-12_H-768_A-12'保存一个经过微调的二进制分类模型。我正在使用tf2
代码建立了模型结构
bert_classifier, bert_encoder =bert.bert_models.classifier_model(bert_config, num_labels=2)
然后:
# import pre-trained model structure from the check point file
checkpoint = tf.train.Checkpoint(model=bert_encoder)
checkpoint.restore(
os.path.join(gs_folder_bert, 'bert_model.ckpt')).assert_consumed()
然后:我编译并拟合了模型
bert_classifier.compile(
optimizer=optimizer,
loss=loss,
metrics=metrics)
bert_classifier.fit(
Text_train, Label_train,
validation_data=(Text_val, Label_val),
batch_size=32,
epochs=1)
最后:我将模型保存在模型文件夹中,然后在其中自动生成一个名为save_model.pb的文件。
bert_classifier.save('/content/drive/My Drive/model')
also tried this
tf.saved_model.save(bert_classifier, export_dir='/content/drive/My Drive/omg')
现在我尝试加载模型并将其应用于测试数据:
from tensorflow import keras
ttt = keras.models.load_model('/content/drive/My Drive/model')
我知道了
KeyError Traceback (most recent call last)
<ipython-input-77-93f80aa585da> in <module>()
----> 1 tf.keras.models.load_model(filepath='/content/drive/My Drive/omg', custom_objects={'Transformera':bert_classifier})
9 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/saved_model/load.py in _revive_graph_network(self, metadata, node_id)
392 else:
393 model = models_lib.Functional(
--> 394 inputs=[], outputs=[], name=config['name'])
395
396 # Record this model and its layers. This will later be used to reconstruct
KeyError: 'name'
此错误消息对我的操作没有帮助...请提供建议。
我还尝试将模型保存为h5格式,但是在加载时
ttt = keras.models.load_model('/content/drive/My Drive/model.h5')
我遇到了这个错误
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-36-12f76139ec24> in <module>()
----> 1 ttt = keras.models.load_model('/content/drive/My Drive/model.h5')
5 frames
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in class_and_config_for_serialized_keras_object(config, module_objects, custom_objects, printable_module_name)
294 cls = get_registered_object(class_name, custom_objects, module_objects)
295 if cls is None:
--> 296 raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
297
298 cls_config = config['config']
ValueError: Unknown layer: BertClassifier
答案 0 :(得分:1)
似乎您对问题有正确的答案:'/content/drive/My Drive/model'
将因空格字符而失败。
您可以尝试转义退格:'/content/drive/My\ Drive/model'
。
其他选项,在我在保存和加载时遇到完全相同的问题之后。有帮助的是只保存预训练模型的权重而不保存整个模型:
看看这里:https://keras.io/api/models/model_saving_apis/,尤其是方法 save_weights()
和 load_weights()
。