是否可以为Tensorflow节点设置别名?

时间:2017-03-15 10:13:16

标签: variables namespaces tensorflow alias serving

我有一个复杂的网络,为此,我创建了一个非常简单的类,用于在将模型序列化为冻结图形文件后对其进行推理。

问题是在这个文件中我需要用他的命名空间加载变量,这可能最终取决于我如何构建模型。在我的情况下,这样结束:

# Load the input and output node with a singular namespace depending on the model
self.input_node = self.sess.graph.get_tensor_by_name("create_model/mymodelname/input_node:0")
self.output_node = self.sess.graph.get_tensor_by_name("create_model/mymodelname/out/output_node:0")

我想在将这两个节点存储到模型中之前给它们一个别名,例如它们最终会有一个通用名称,然后我的推理类可以用作获取模型的通用类。在这种情况下,我最终会做这样的事情:

# Load the input and output node with a general node namespace
self.input_node = self.sess.graph.get_tensor_by_name("input_node:0")
self.output_node = self.sess.graph.get_tensor_by_name("output_node:0")

那么给他们别名有什么选择吗?我真的没找到任何东西。

非常感谢!

1 个答案:

答案 0 :(得分:2)

您可以使用tf.identity作为输出。

output_node = sess.graph.get_tensor_by_name("create_model/mymodelname/output_node:0")
tf.identity(output_node, name="output_node")

将创建一个名为“output_node”的新passthrough op,并将从您指定的节点获取其值。

输入有点棘手,您需要更改构建模型的方式 - 例如让它从外部获取输入,然后创建一个具有固定名称的输入占位符,然后将其传递给函数构建模型。