如何在“用于分类的递归神经网络”教程中导出保存的模型

时间:2018-08-01 03:48:18

标签: tensorflow tensorflow-serving

我正在阅读tensorflow的教程,并且在如何保存经过训练的模型上遇到了问题。

在本教程中,定义了递归神经网络并对其进行了训练以进行图纸分类。这是对应的代码:

estimator = tf.estimator.Estimator(
      model_fn=model_fn,
      model_dir=output_dir,
      config=config,
      params=model_params)
  # Train the model.
  tf.contrib.learn.Experiment(
      estimator=estimator,
      train_input_fn=get_input_fn(
          mode=tf.contrib.learn.ModeKeys.TRAIN,
          tfrecord_pattern=FLAGS.training_data,
          batch_size=FLAGS.batch_size),
      train_steps=FLAGS.steps,
      eval_input_fn=get_input_fn(
          mode=tf.contrib.learn.ModeKeys.EVAL,
          tfrecord_pattern=FLAGS.eval_data,
          batch_size=FLAGS.batch_size),
      min_eval_frequency=1000)

tutorials没有提供用于显示如何导出和保存模型的代码。我怎样才能做到这一点?

1 个答案:

答案 0 :(得分:0)

本教程利用了Estimator API。训练模型后,可以通过调用export_savedmodel()方法来保存模型:

export_dir = './' # path to store the model
estimator.export_savedmodel(export_dir, serving_input_fn)

serving_input_fn等于培训期间input_fn的服务时间。此函数应返回一个ServingInputReceiver对象。该对象的目的是接收服务请求,对其进行解析,然后将其发送到模型以进行推断。要进行解析,您需要提供feature_spec字典,该字典告诉解析函数期望哪些功能。从文档中:

feature_spec = {'foo': tf.FixedLenFeature(...),
                'bar': tf.VarLenFeature(...)}

有关如何从头开始构建它的详细说明,请参见TF documentation

在大多数情况下,您可以使用build_parsing_serving_input_receiver_fnbuild_raw_serving_input_receiver_fn实用程序功能来构建serving_input_fn。解析接收器需要如上所述的功能规范,原始接收器需要从字符串到张量的映射,并将允许您将“原始”(非序列化)输入数据作为请求传递给模型。例如:

serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
feature_spec,
default_batch_size=None)