如何使用saved_model API定期保存tensorflow模型?

时间:2018-03-14 20:36:20

标签: python tensorflow

因此,出于各种原因(例如语言独立性),我想使用tensorflow的saved_model API来保存/加载模型。我可以在训练结束时调用builder.add_meta_graph_and_variables()来保存所有内容(并成功恢复),但我没有看到任何定期保存的方法。关于此的Tensorflow文档非常稀疏,它们提供的模板代码(here)对我没有帮助:

...
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)

with tf.Session(graph=tf.Graph()) as sess:
  ...
  builder.add_meta_graph_and_variables(sess,
                                  ["foo-tag"],
                                  signature_def_map=foo_signatures,
                                  assets_collection=foo_assets)
...

with tf.Session(graph=tf.Graph()) as sess:
  ...
  builder.add_meta_graph(["bar-tag", "baz-tag"])
...

builder.save()

调用builder.save()不会将新变量保存到模型中。它只是更新模型protobuf。

我错过了什么?如何保存,例如第n个纪元使用saved_model

2 个答案:

答案 0 :(得分:0)

好吧,在查看张量流代码here和其他地方后,看起来答案是"你不能"#34;。 SavedModelBuilder实际上只是为训练阶段之外的模型设计的,它允许您添加元图并选择要加载/保存的变量集(即TRAINING与SERVING),但这样做。例如,SavedModelBuilder.add_meta_graph_and_variables只能被调用一次,并且没有SavedModelBuilder.update_variables或类似的东西。另一方面,在训练时,您需要使用Saver类并保存检查点和相关文件。为什么没有一个统一的系统,我不知道,但显然是这样的。

答案 1 :(得分:0)

调用builder.add_meta_graph_and_variables()时会保存变量,因为在其内部调用了saver.save()See here

解决方案:

只需在saver.save(sess, export_dir+'/variables/variables', write_meta_graph=False, write_state=False)之前致电builder.save()