我们如何保存由contrib.learn.Classifier制作的Tensorflow模型?

时间:2017-08-17 19:38:01

标签: tensorflow

我想保存一个由contrib.learn.Classifier创建的模型,但我不知道如何引用它的内部节点。这是我在vanilla Tensorflow模型中使用的代码(y = W * x + b),它运行良好。

W = tf.Variable([], dtype=tf.float32)
b = tf.Variable([], dtype=tf.float32)
x = tf.placeholder(tf.float32, name="x")
my_model = tf.add(W * x, b, name="model")
...  # training
builder = tf.saved_model.builder.SavedModelBuilder("/tmp/saved_model")
builder.add_meta_graph_and_variables(sess, ["predict_tag"], signature_def_map= {
          "model": tf.saved_model.signature_def_utils.predict_signature_def(
              inputs= {"x": x},
              outputs= {"model": my_model})
          })
builder.save()

现在,如果我使用contrib.learn.Classifier

estimator = tf.contrib.learn.LinearClassifier(feature_columns=feature_columns)
estimator.fit(input_fn=train_input_fn, steps=1000)

对于后者builder,我如何同样使用上面的estimator?请注意,我不想做tf.train.Saver().save(sess, "/tmp/model");使用saved_model.builder是必需的。谢谢!

1 个答案:

答案 0 :(得分:0)

您可以使用估算工具的export_savedmodel函数将推理图表作为SavedModel导出到给定目录。tf.contrib.learn.LinearClassifier

from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib
from tensorflow.contrib.learn.python.learn.utils.input_fn_utils import build_parsing_serving_input_fn

# create feature specs from feature columns
feature_spec = feature_column_lib.create_feature_spec_for_parsing(
  feature_columns)

# create the input function 
serving_input_fn = build_parsing_serving_input_fn(feature_spec)

# finally save the model
estimator.export_savedmodel('/path/to/save/my_model/', serving_input_fn=input_receiver_fn)