我正在使用 Hugging Face 的 BERT(TFBertModel)训练一个二分类器。模型如下所示:
def get_model(lr=0.00001):
    """Build and compile a binary classifier on top of multilingual BERT.

    The hidden state of the first token from ``bert-base-multilingual-cased``
    is fed to a single sigmoid unit.

    Args:
        lr: Learning rate for the Adam optimizer.

    Returns:
        A compiled ``tf.keras.Model`` taking int32 token ids of length 512
        and producing one sigmoid probability per example.

    Note:
        The wrapped ``TFBertModel`` is not a graph (functional) network, so
        ``get_config()`` raises ``NotImplementedError``. As a result,
        ``model.save("x.h5")`` (HDF5 format) fails. Save with
        ``save_format="tf"`` or use ``save_weights`` instead.
    """
    # `shape` must be a tuple: `(512)` is just the int 512, `(512,)` is a 1-tuple.
    inp_bert = Input(shape=(512,), dtype="int32")
    # NOTE(review): output [0] is assumed to be the last hidden state,
    # (batch, seq_len, hidden) — confirm for the installed transformers version.
    bert = TFBertModel.from_pretrained('bert-base-multilingual-cased')(inp_bert)[0]
    # Take the first token's hidden state (presumably [CLS]) as the document
    # encoding; equivalent to squeezing the 0:1 slice along axis 1.
    doc_encodings = bert[:, 0, :]
    out = Dense(1, activation="sigmoid")(doc_encodings)
    model = Model(inp_bert, out)
    # `learning_rate` is the documented argument name; `lr` is a deprecated
    # alias that newer Keras versions no longer accept.
    adam = optimizers.Adam(learning_rate=lr)
    model.compile(optimizer=adam, loss="binary_crossentropy", metrics=["accuracy"])
    return model
对分类任务进行微调后,我想保存模型。
# A ".h5" target goes through save_model_to_hdf5, which calls get_config() on every layer.
model.save(filepath="best_model.h5")
但是这会引发NotImplementedError:
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
<ipython-input-55-8c5545f0cd9b> in <module>()
----> 1 model.save("best_spam.h5")
2 # import transformers
~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in save(self, filepath, overwrite, include_optimizer, save_format, signatures, options)
973 """
974 saving.save_model(self, filepath, overwrite, include_optimizer, save_format,
--> 975 signatures, options)
976
977 def save_weights(self, filepath, overwrite=True, save_format=None):
~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/save.py in save_model(model, filepath, overwrite, include_optimizer, save_format, signatures, options)
110 'or using `save_weights`.')
111 hdf5_format.save_model_to_hdf5(
--> 112 model, filepath, overwrite, include_optimizer)
113 else:
114 saved_model_save.save(model, filepath, overwrite, include_optimizer,
~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/hdf5_format.py in save_model_to_hdf5(model, filepath, overwrite, include_optimizer)
97
98 try:
---> 99 model_metadata = saving_utils.model_metadata(model, include_optimizer)
100 for k, v in model_metadata.items():
101 if isinstance(v, (dict, list, tuple)):
~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/saving_utils.py in model_metadata(model, include_optimizer, require_config)
163 except NotImplementedError as e:
164 if require_config:
--> 165 raise e
166
167 metadata = dict(
~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/saving/saving_utils.py in model_metadata(model, include_optimizer, require_config)
160 model_config = {'class_name': model.__class__.__name__}
161 try:
--> 162 model_config['config'] = model.get_config()
163 except NotImplementedError as e:
164 if require_config:
~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in get_config(self)
885 if not self._is_graph_network:
886 raise NotImplementedError
--> 887 return copy.deepcopy(get_network_config(self))
888
889 @classmethod
~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in get_network_config(network, serialize_layer_fn)
1940 filtered_inbound_nodes.append(node_data)
1941
-> 1942 layer_config = serialize_layer_fn(layer)
1943 layer_config['name'] = layer.name
1944 layer_config['inbound_nodes'] = filtered_inbound_nodes
~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/utils/generic_utils.py in serialize_keras_object(instance)
138 if hasattr(instance, 'get_config'):
139 return serialize_keras_class_and_config(instance.__class__.__name__,
--> 140 instance.get_config())
141 if hasattr(instance, '__name__'):
142 return instance.__name__
~/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow_core/python/keras/engine/network.py in get_config(self)
884 def get_config(self):
885 if not self._is_graph_network:
--> 886 raise NotImplementedError
887 return copy.deepcopy(get_network_config(self))
888
NotImplementedError:
我知道 Hugging Face 为 TFBertModel 提供了 save_pretrained() 方法。但我需要把它包装在 tf.keras.Model 中,因为之后还要向该网络添加其他组件/功能。有人能建议一种保存当前整个模型的方法吗?
答案 0(得分:4)
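这确实是 TensorFlow 2.0 的问题。以 .h5(HDF5)格式保存时,Keras 需要对每一层调用 get_config()。而 TFBertModel 是子类化模型,不是图网络(graph network),它的 get_config() 会直接抛出 NotImplementedError(见上面回溯的最后几帧)。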
请使用 model.save("model_name", save_format='tf'),以 TensorFlow SavedModel 格式保存(结果是一个目录,而不是单个 .h5 文件)。这种格式不依赖 get_config()。
或者,可以改用 model.save_weights(...) 只保存权重(之后用 get_model() 重建模型再加载)。也可以尝试升级或降级 TensorFlow。