继续训练SavedModel或从SavedModel加载检查点

时间:2019-10-07 07:37:35

标签: tensorflow tensorflow-serving tensorflow-estimator

在tensorflow 1.14中,很明显tf.compat.v1.train.init_from_checkpoint可以加载ckpt来继续训练(或热启动)。但是,我在SavedModel中找不到任何相应的方法,并且tf.estimator.WarmStartSetting也仅支持ckpt。对我来说很奇怪,因为this answer提到应该在SavedModel中存储一个检查点。有谁知道:

  1. 如何在SavedModel中加载检查点?或
  2. 如何在SavedModel上进行热身培训?

1 个答案:

答案 0 :(得分:0)

为了加载SavedModel以继续训练,可以使用tf.saved_model.loader.load,如下所示:

graph = tf.Graph()
with tf.Session(graph=graph) as sess:
  tf.saved_model.loader.load(sess, [tag_constants.SERVING], saved_model_location)

为了馈送新的输入数据,您可以获取输入张量名称,如下所示:

signature_def = meta_graph_def.signature_def[tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
inputs = [v.name for v in signature_def.inputs.values()]
input_tensors = [node.split(":")[0] for node in inputs]

然后,您可以制作一些feed_dict,以将新的输入馈送到输入张量。获取输出张量的方法可以类似于我上面概述的方法。