我在使用tf.SavedModel恢复tf.estimator.BoostedTreesRegressor模型时遇到问题。使用tf.contrib.predictor.from_saved_model()从保存的模型目录中重新加载模型时,出现以下错误:
KeyError: "The name 'boosted_trees/QuantileAccumulator/' refers to an Operation not in the graph."
此错误仅在使用数值特征(例如tf.feature_column.numeric_column)时发生。仅使用分类列时,重新加载模型可以正常工作。
在不保存/还原模型的情况下,BoostedTreesRegressor使用所有特征都可以成功评估和预测。
以下估算器保存/还原方案已成功运行:
-具有数值和分类特征的DNNRegressor
-具有数值和分类特征的LinearRegressor
-仅具有分类特征的BoostedTreesRegressor
fc = tf.feature_column

# Integer-typed numeric inputs.
numeric_columns = [
    fc.numeric_column(key, dtype=tf.int64) for key in ('f1', 'f2')
]

# Vocabulary-list categorical inputs, each wrapped in an indicator column.
# f3 and f4 are the vocabulary lists and are defined outside this snippet.
categorical_columns = [
    fc.indicator_column(fc.categorical_column_with_vocabulary_list(key, vocab))
    for key, vocab in (('f3', f3), ('f4', f4))
]

# Column order is preserved: f1, f2, f3, f4.
feature_columns = numeric_columns + categorical_columns

# Parsing spec derived from the feature columns; used for the serving
# input receiver below.
feature_spec = fc.make_parse_example_spec(feature_columns)

# Keyword arguments forwarded to the estimator constructor.
params = dict(
    feature_columns=feature_columns,
    n_batches_per_layer=n_batches,
    n_trees=200,
    max_depth=6,
    learning_rate=0.01,
)

# Build and train the boosted-trees regressor.
regressor = tf.estimator.BoostedTreesRegressor(**params)
regressor.train(train_input_fn, max_steps=400)

# Export the trained estimator as a SavedModel under 'saved_model'.
serving_input_receiver_fn = (
    tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)
)
regressor.export_saved_model('saved_model', serving_input_receiver_fn)
.
.
.
# `latest` holds the path to the exported SavedModel; the slice drops its last
# four characters before the path is handed to the predictor factory.
# NOTE(review): the reason for trimming four characters is not shown here — confirm.
# This is the call that raises the KeyError shown in the traceback below.
predict_fn = predictor.from_saved_model(latest[:-4])
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-101-ee20beae4424> in <module>
----> 1 predict_fn = predictor.from_saved_model(latest[:-4])
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/contrib/predictor/predictor_factories.py in from_saved_model(export_dir, signature_def_key, signature_def, input_names, output_names, tags, graph, config)
151 tags=tags,
152 graph=graph,
--> 153 config=config)
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/contrib/predictor/saved_model_predictor.py in __init__(self, export_dir, signature_def_key, signature_def, input_names, output_names, tags, graph, config)
151 with self._graph.as_default():
152 self._session = session.Session(config=config)
--> 153 loader.load(self._session, tags.split(','), export_dir)
154
155 if input_names is None:
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
322 'in a future version' if date is None else ('after %s' % date),
323 instructions)
--> 324 return func(*args, **kwargs)
325 return tf_decorator.make_decorator(
326 func, new_func, 'deprecated',
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/python/saved_model/loader_impl.py in load(sess, tags, export_dir, import_scope, **saver_kwargs)
267 """
268 loader = SavedModelLoader(export_dir)
--> 269 return loader.load(sess, tags, import_scope, **saver_kwargs)
270
271
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/python/saved_model/loader_impl.py in load(self, sess, tags, import_scope, **saver_kwargs)
418 with sess.graph.as_default():
419 saver, _ = self.load_graph(sess.graph, tags, import_scope,
--> 420 **saver_kwargs)
421 self.restore_variables(sess, saver, import_scope)
422 self.run_init_ops(sess, tags, import_scope)
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/python/saved_model/loader_impl.py in load_graph(self, graph, tags, import_scope, **saver_kwargs)
348 with graph.as_default():
349 return tf_saver._import_meta_graph_with_return_elements( # pylint: disable=protected-access
--> 350 meta_graph_def, import_scope=import_scope, **saver_kwargs)
351
352 def restore_variables(self, sess, saver, import_scope=None):
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/python/training/saver.py in _import_meta_graph_with_return_elements(meta_graph_or_file, clear_devices, import_scope, return_elements, **kwargs)
1455 import_scope=import_scope,
1456 return_elements=return_elements,
-> 1457 **kwargs))
1458
1459 saver = _create_saver_from_imported_meta_graph(
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/python/framework/meta_graph.py in import_scoped_meta_graph_with_return_elements(meta_graph_or_file, clear_devices, graph, import_scope, input_map, unbound_inputs_col_name, restore_collections_predicate, return_elements)
850 for value in field.value:
851 col_op = graph.as_graph_element(
--> 852 ops.prepend_name_scope(value, scope_to_prepend_to_names))
853 graph.add_to_collection(key, col_op)
854 elif kind == "int64_list":
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in as_graph_element(self, obj, allow_tensor, allow_operation)
3476
3477 with self._lock:
-> 3478 return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
3479
3480 def _as_graph_element_locked(self, obj, allow_tensor, allow_operation):
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in _as_graph_element_locked(self, obj, allow_tensor, allow_operation)
3536 if name not in self._nodes_by_name:
3537 raise KeyError("The name %s refers to an Operation not in the "
-> 3538 "graph." % repr(name))
3539 return self._nodes_by_name[name]
3540
KeyError: "The name 'boosted_trees/QuantileAccumulator/' refers to an Operation not in the graph."
答案 0 :(得分:0)
如果您使用的是 TensorFlow 1.x(1.14、1.15),则可以使用 `tf.compat.v1.saved_model.load`、`tf.compat.v1.saved_model.loader.load` 或 `tf.saved_model.loader.load` 加载保存的模型。
如果您使用的是 TensorFlow 2,下面是使用 `tf.estimator.BoostedTreesClassifier` 成功保存(Saving)和恢复(Restoring)模型的代码:
# Passed to the estimator as n_batches_per_layer.
n_batches = 1
est = tf.estimator.BoostedTreesClassifier(
    feature_columns=feature_columns,
    n_batches_per_layer=n_batches,
)

# Training halts as soon as the requested number of trees has been built,
# regardless of how many steps have run.
est.train(train_input_fn, max_steps=100)

# Evaluate, clear the notebook cell output, and show the metrics as a Series.
eval_metrics = est.evaluate(eval_input_fn)
clear_output()
print(pd.Series(eval_metrics))

# Export using a serving input receiver built from the feature columns'
# parsing spec.
feature_spec = fc.make_parse_example_spec(feature_columns)
serving_input_receiver_fn = (
    tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)
)
export_dir = est.export_saved_model('saved_model', serving_input_receiver_fn)

# Load the exported SavedModel back from the path returned by the export call.
imported = tf.saved_model.load(export_dir)
有关使用 TensorFlow 2 的完整可运行代码,请参阅此 GitHub Gist。