我正在阅读tensorflow的教程,并且在如何保存经过训练的模型上遇到了问题。
在本教程中,定义了递归神经网络并对其进行了训练以进行图纸分类。这是对应的代码:
estimator = tf.estimator.Estimator(
model_fn=model_fn,
model_dir=output_dir,
config=config,
params=model_params)
# Train the model.
tf.contrib.learn.Experiment(
estimator=estimator,
train_input_fn=get_input_fn(
mode=tf.contrib.learn.ModeKeys.TRAIN,
tfrecord_pattern=FLAGS.training_data,
batch_size=FLAGS.batch_size),
train_steps=FLAGS.steps,
eval_input_fn=get_input_fn(
mode=tf.contrib.learn.ModeKeys.EVAL,
tfrecord_pattern=FLAGS.eval_data,
batch_size=FLAGS.batch_size),
min_eval_frequency=1000)
tutorials没有提供用于显示如何导出和保存模型的代码。我怎样才能做到这一点?
答案 0 :(得分:0)
本教程利用了Estimator
API。训练模型后,可以通过调用export_savedmodel()
方法来保存模型:
export_dir = './' # path to store the model
estimator.export_savedmodel(export_dir, serving_input_fn)
serving_input_fn
等于培训期间input_fn
的服务时间。此函数应返回一个ServingInputReceiver
对象。该对象的目的是接收服务请求,对其进行解析,然后将其发送到模型以进行推断。要进行解析,您需要提供feature_spec
字典,该字典告诉解析函数期望哪些功能。从文档中:
feature_spec = {'foo': tf.FixedLenFeature(...),
'bar': tf.VarLenFeature(...)}
有关如何从头开始构建它的详细说明,请参见TF documentation。
在大多数情况下,您可以使用build_parsing_serving_input_receiver_fn
或build_raw_serving_input_receiver_fn
实用程序功能来构建serving_input_fn
。解析接收器需要如上所述的功能规范,原始接收器需要从字符串到张量的映射,并将允许您将“原始”(非序列化)输入数据作为请求传递给模型。例如:
serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
feature_spec,
default_batch_size=None)