我的模型是通过tf.estimator.Estimator
API创建,训练和保存的。
通常,我只会加载并进行预测
classifier = tf.estimator.Estimator(
model_fn=model_fn,
model_dir=model_dir)
preds = classifier.predict(
input_fn=input_fn)
但是,当我使用SavedModel API时,我无法找到一种方法来处理基本相同的事情。
g1 = tf.Graph()
with tf.Session(graph = g1) as sess:
meta_graph_def = tf.saved_model.loader.load(
sess=sess,
tags=[tf.saved_model.tag_constants.SERVING],
export_dir=model_dir)
# what to do next???