在tensorflow 1.14中,很明显tf.compat.v1.train.init_from_checkpoint
可以加载ckpt
来继续训练(或热启动)。但是,我在SavedModel中找不到任何相应的方法,并且tf.estimator.WarmStartSetting也仅支持ckpt
。对我来说很奇怪,因为this answer提到应该在SavedModel
中存储一个检查点。有谁知道:
答案 0 :(得分:0)
为了加载SavedModel以继续训练,可以使用tf.saved_model.loader.load,如下所示:
graph = tf.Graph()
with tf.Session(graph=graph) as sess:
tf.saved_model.loader.load(sess, [tag_constants.SERVING], saved_model_location)
为了馈送新的输入数据,您可以获取输入张量名称,如下所示:
signature_def = meta_graph_def.signature_def[tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
inputs = [v.name for v in signature_def.inputs.values()]
input_tensors = [node.split(":")[0] for node in inputs]
然后,您可以制作一些feed_dict
,以将新的输入馈送到输入张量。获取输出张量的方法可以类似于我上面概述的方法。