tf.estimator.BoostedTreesRegressor SavedModel还原问题

时间:2019-06-20 18:55:41

标签: tensorflow tensorflow-estimator

我在从 SavedModel 恢复 tf.estimator.BoostedTreesRegressor 模型时遇到问题。使用 tf.contrib.predictor.from_saved_model() 从导出的模型目录重新加载模型时，出现以下错误：

  

KeyError: "The name 'boosted_trees/QuantileAccumulator/' refers to an Operation not in the graph."（即：名称 'boosted_trees/QuantileAccumulator/' 所指向的操作不在图中。）

此错误仅在使用数值特征（例如 tf.feature_column.numeric_column）时出现；只使用分类特征列时，模型可以正常重新加载。

如果不经过保存/恢复，BoostedTreesRegressor 在使用全部特征时都能正常评估并成功预测。

以下 Estimator 的保存/恢复组合均能正常工作：
- 同时使用数值特征和分类特征的 DNNRegressor
- 同时使用数值特征和分类特征的 LinearRegressor
- 仅使用分类特征的 BoostedTreesRegressor

# Question's repro: train a BoostedTreesRegressor on two numeric and two
# categorical features, then export it as a SavedModel whose serving input
# is built from feature_spec.
# NOTE(review): tf, f3, f4, n_batches and train_input_fn are defined outside
# this snippet; f3/f4 are presumably the vocabulary lists for 'f3'/'f4'.
fc = tf.feature_column
feature_columns = [
# Numeric features -- per the question, the reload failure only appears
# when columns like these are present.
fc.numeric_column('f1', dtype=tf.int64),
fc.numeric_column('f2', dtype=tf.int64),
# Categorical features, one-hot encoded via indicator_column.
fc.indicator_column(
               fc.categorical_column_with_vocabulary_list('f3',f3)),
fc.indicator_column(
               fc.categorical_column_with_vocabulary_list('f4',f4))
]

# Parsing spec derived from the feature columns; reused below for serving.
feature_spec = fc.make_parse_example_spec(feature_columns)

# Keyword arguments passed straight to BoostedTreesRegressor.
params = {
    'feature_columns' : feature_columns,
    'n_batches_per_layer' : n_batches,
    'n_trees': 200,
    'max_depth': 6,
    'learning_rate': 0.01
}

regressor = tf.estimator.BoostedTreesRegressor(**params)
regressor.train(train_input_fn, max_steps=400)

# Serving input that parses serialized examples according to feature_spec.
serving_input_receiver_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)

# Export under the 'saved_model' base directory.
regressor.export_saved_model('saved_model', serving_input_receiver_fn)

.
.
.
# latest is path to saved model
# NOTE(review): latest[:-4] drops the last four characters of `latest`,
# presumably a trailing suffix -- confirm the result is the export directory.
# `predictor` is presumably tf.contrib.predictor; this call is the one that
# raises the KeyError in the traceback below.
predict_fn = predictor.from_saved_model(latest[:-4])
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-101-ee20beae4424> in <module>
----> 1 predict_fn = predictor.from_saved_model(latest[:-4])
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/contrib/predictor/predictor_factories.py in from_saved_model(export_dir, signature_def_key, signature_def, input_names, output_names, tags, graph, config)
    151       tags=tags,
    152       graph=graph,
--> 153       config=config)
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/contrib/predictor/saved_model_predictor.py in __init__(self, export_dir, signature_def_key, signature_def, input_names, output_names, tags, graph, config)
    151     with self._graph.as_default():
    152       self._session = session.Session(config=config)
--> 153       loader.load(self._session, tags.split(','), export_dir)
    154 
    155     if input_names is None:
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    322               'in a future version' if date is None else ('after %s' % date),
    323               instructions)
--> 324       return func(*args, **kwargs)
    325     return tf_decorator.make_decorator(
    326         func, new_func, 'deprecated',
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/python/saved_model/loader_impl.py in load(sess, tags, export_dir, import_scope, **saver_kwargs)
    267   """
    268   loader = SavedModelLoader(export_dir)
--> 269   return loader.load(sess, tags, import_scope, **saver_kwargs)
    270 
    271 
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/python/saved_model/loader_impl.py in load(self, sess, tags, import_scope, **saver_kwargs)
    418     with sess.graph.as_default():
    419       saver, _ = self.load_graph(sess.graph, tags, import_scope,
--> 420                                  **saver_kwargs)
    421       self.restore_variables(sess, saver, import_scope)
    422       self.run_init_ops(sess, tags, import_scope)
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/python/saved_model/loader_impl.py in load_graph(self, graph, tags, import_scope, **saver_kwargs)
    348     with graph.as_default():
    349       return tf_saver._import_meta_graph_with_return_elements(  # pylint: disable=protected-access
--> 350           meta_graph_def, import_scope=import_scope, **saver_kwargs)
    351 
    352   def restore_variables(self, sess, saver, import_scope=None):
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/python/training/saver.py in _import_meta_graph_with_return_elements(meta_graph_or_file, clear_devices, import_scope, return_elements, **kwargs)
   1455           import_scope=import_scope,
   1456           return_elements=return_elements,
-> 1457           **kwargs))
   1458 
   1459   saver = _create_saver_from_imported_meta_graph(
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/python/framework/meta_graph.py in import_scoped_meta_graph_with_return_elements(meta_graph_or_file, clear_devices, graph, import_scope, input_map, unbound_inputs_col_name, restore_collections_predicate, return_elements)
    850           for value in field.value:
    851             col_op = graph.as_graph_element(
--> 852                 ops.prepend_name_scope(value, scope_to_prepend_to_names))
    853             graph.add_to_collection(key, col_op)
    854         elif kind == "int64_list":
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in as_graph_element(self, obj, allow_tensor, allow_operation)
   3476 
   3477     with self._lock:
-> 3478       return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
   3479 
   3480   def _as_graph_element_locked(self, obj, allow_tensor, allow_operation):
/usr/local/anaconda3/envs/zume/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in _as_graph_element_locked(self, obj, allow_tensor, allow_operation)
   3536         if name not in self._nodes_by_name:
   3537           raise KeyError("The name %s refers to an Operation not in the "
-> 3538                          "graph." % repr(name))
   3539         return self._nodes_by_name[name]
   3540 
KeyError: "The name 'boosted_trees/QuantileAccumulator/' refers to an Operation not in the graph."

1 个答案:

答案 0（得分：0）

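如果您使用的是 TensorFlow 1.x（1.14、1.15），可以使用 `tf.compat.v1.saved_model.load`、`tf.compat.v1.saved_model.loader.load` 或 `tf.saved_model.loader.load` 加载保存的模型（这些函数需要传入 Session、tags 和导出目录，即 `load(sess, tags, export_dir)`）。

注意：问题中的回溯显示，`tf.contrib.predictor.from_saved_model` 正是在内部调用 `loader.load` 时抛出该 KeyError 的，因此在出现此问题的 TensorFlow 版本中，直接调用这些函数可能仍会遇到同样的错误。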

如果您使用的是 TensorFlow 2.x，下面的代码可以成功保存并恢复（加载）一个 tf.estimator.BoostedTreesClassifier 模型：

# Answer's TF 2.x example: train, evaluate, export and reload a
# BoostedTreesClassifier.
# NOTE(review): tf, pd, fc, clear_output (presumably IPython.display),
# feature_columns, train_input_fn and eval_input_fn are defined outside this
# snippet (see the linked Gist).
n_batches = 1
est = tf.estimator.BoostedTreesClassifier(feature_columns,
                                          n_batches_per_layer=n_batches)

# The model will stop training once the specified number of trees is built, not
# based on the number of steps.
est.train(train_input_fn, max_steps=100)

# Eval.
result = est.evaluate(eval_input_fn)
# Clear previous notebook output, then show the evaluation metrics as a Series.
clear_output()
print(pd.Series(result))

# Parsing spec derived from the same feature columns used for training.
feature_spec = fc.make_parse_example_spec(feature_columns)

serving_input_receiver_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)

# The return value of export_saved_model is used directly as the load path.
Exported_Path = est.export_saved_model('saved_model', serving_input_receiver_fn)

# TF 2.x loader; the answer reports this path as working.
imported = tf.saved_model.load(Exported_Path)

使用 TensorFlow 2.x 的完整可运行代码，请参阅这个 GitHub Gist。