在Tensorflow 2.0中保存和加载模型

时间:2020-06-03 14:59:29

标签: python tensorflow machine-learning

我使用此代码从tensorflow 2.x中的预制估计器中保存了一个模型

import os
serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
tf.feature_column.make_parse_example_spec(my_feature_columns))
estimator_base_path = os.path.join( 'from_estimator')
estimator_path = classifier.export_saved_model(estimator_base_path, serving_input_fn)

此代码创建一个包含.pb文件的文件夹 我将来需要重用此模型,我尝试加载此功能

saved_model_obj = tf.compat.v2.saved_model.load(export_dir="/model_dir/")

但是当我尝试使用加载的模型做出预测时,会引发此错误

predictions = saved_model_obj.predict(
input_fn=lambda: input_fn(predict_x))


---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-23-a9902ff8210c> in <module>
----> 1 predictions = saved_model_obj.predict(
      2     input_fn=lambda: input_fn(predict_x))

AttributeError: 'AutoTrackable' object has no attribute 'predict'

如何加载.pb文件并进行预测,就像我从未保存过并加载它一样?

1 个答案:

答案 0 :(得分:0)

当我保存模型以备后用时,通常会这样做:

假设您的模型是model

model.save('my_model.h5') 

这会将modoel保存为hdf5格式。

然后当我不得不用它来再次预测时,我可以:

new_model = tf.keras.models.load_model('my_model.h5')

然后您可以new_model.predict()