假设我用以下代码保存了一个模型
tf.saved_model.simple_save(sess, export_dir, in={'input_x': x, 'input_y':y}, out={'output_z':z})
现在我将保存的模型加载回另一个python程序中
with tf.Session() as sess:
tf.saved_model.loader.load(sess, ['serve'], export_dir)
现在的问题是,当调用simple_save()方法时,如何在输入/输出参数中指定的“ input_x”,“ input_y”,“ output_z”键抓住x,y,z张量的句柄? / p>
我在网上找到的唯一解决方案是在创建x,y,z张量时显式命名,然后使用这些名称从图中检索它们,这似乎是多余的,因为我们在调用时为它们指定了键simple_save()。
答案 0 :(得分:2)
我确实遇到了您的问题,经过一番调查(我认为TF文档不佳),我找到了下一个解决方案:
使用返回的MetaGraphDef对象查找输入\输出名称映射。
graph = tf.Graph()
with graph.as_default():
metagraph = tf.saved_model.loader.load(sess, [tag_constants.SERVING],save_path)
inputs_mapping = dict(metagraph.signature_def['serving_default'].inputs)
outputs_mapping = dict(metagraph.signature_def['serving_default'].outputs)
此代码将为您提供保存到“ TensorInfo”对象时提供的名称之间的映射,您可以从他轻松获得映射的张量名称,例如:
my_input = inputs_mapping['my_input_name'].name
my_input_t = graph.get_tensor_by_name(my_input)
答案 1 :(得分:0)
tf.saved_model.loader.load
的返回值是一个MetaGraphDef
协议缓冲区,该缓冲区应具有保存保存的模型时设置的所有签名;这些应该包含您想要的名称。