通过TensorFlow估算器模型(2.0)保存,加载和预测

时间:2019-11-20 16:55:38

标签: tensorflow

在任何地方都可以找到有关在TF2中序列化和还原Estimator模型的指南吗?该文档非常参差不齐,其中许多文档未更新为TF2。我还没有看到Estimator被保存,从磁盘加载并用于根据新输入进行预测的任何地方的清晰完整示例。

TBH,这看起来有多复杂,我有点困惑。估计器被认为是拟合标准模型的简单,相对高级的方法,但是在生产中使用估计器的过程似乎非常不可思议。例如,当我通过tf.saved_model.load(export_path)从磁盘加载模型时,会得到一个AutoTrackable对象:

<tensorflow.python.training.tracking.tracking.AutoTrackable at 0x7fc42e779f60>

不清楚我为什么不找回Estimator。似乎曾经有一个有用的发音功能tf.contrib.predictor.from_saved_model,但是由于contrib消失了,它似乎不再起作用了(除了,它出现在TFLite中)。

任何指针都将非常有帮助。如您所见,我有点迷茫。

1 个答案:

答案 0 :(得分:3)

也许作者不再需要答案了,但是我能够使用TensorFlow 2.1保存并加载DNNClassifier。

# training.py
from pathlib import Path
import tensorflow as tf

....
# Creating the estimator
estimator = tf.estimator.DNNClassifier(
    model_dir = <model_dir>,
    hidden_units = [1000, 500],
    feature_columns = feature_columns, # this is a list defined earlier
    n_classes = 2,
    optimizer = 'adam')

feature_spec = tf.feature_column.make_parse_example_spec(feature_columns)
export_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)
servable_model_path = Path(estimator.export_saved_model(<model_dir>, export_input_fn).decode('utf8'))
print(f'Model saved at {servable_model_path}')

对于加载,您找到了正确的方法,只需要检索 predict_fn

# testing.py
import tensorflow as tf
import pandas as pd

def predict_input_fn(test_df):
    '''Convert your dataframe using tf.train.Example() and tf.train.Features()'''
    examples = []
    ....
    return tf.constant(examples)

test_df = pd.read_csv('test.csv', ...)

# Loading the estimator
predict_fn = tf.saved_model.load(<model_dir>).signatures['predict']
# Predict
predictions = predict_fn(examples=predict_input_fn(test_df))

希望这也可以帮助其他人(: