假设我有两个不同的图表: 第一个包含x和y: x = tf.placeholder(tf.float32,shape =(1)), y = 2 * x, 第二个包含a和b: a = tf.placeholder(tf.float32,shape =(1)), b = 2 * x。
现在,我想通过在y和a之间添加一些“身份链接”来连接这两个图。换句话说,我想告诉第二个图从第一个图(y)中的某个节点获取其输入(a)。在没有代码重新创建第二个图形的情况下,它很方便,您只需从某个地方对其进行反序列化。一种方法是使用Session.run计算第一个图的输出,然后将其提供给计算第二个图的输出的Session.run调用,但必须有一些干净的方法。
谢谢!
答案 0 :(得分:1)
如果我理解正确,这对你有用吗?
它利用tf.import_graph_def
作业
我们有x
,然后输入第一个图表以获取y = 2 *x
,
然后我们将y
提供给第二个图表以获取b = 2 * y
,对于x = 1.0
,以下代码将生成4.0
。
import tensorflow as tf
FLOAT = tf.float32
tf.reset_default_graph()
def graph_1():
g = tf.Graph()
with g.as_default():
x = tf.placeholder(FLOAT, [], name='x')
y = tf.multiply(2.0, x, name='y')
return g
def graph_2():
g = tf.Graph()
with g.as_default():
a = tf.placeholder(FLOAT, [], name='a')
b = tf.multiply(2.0, a, name='b')
return g
# x = 1.0
x = tf.constant(1.0, FLOAT, [])
# feed x to graph_1 -> y = 2.0
g1 = graph_1()
[g1_y] = tf.import_graph_def(g1.as_graph_def(), input_map={'x': x}, return_elements=['y:0'])
# feed y to graph_2 -> b = 4.0
g2 = graph_2()
[g2_b] = tf.import_graph_def(g2.as_graph_def(), input_map={'a': g1_y}, return_elements=['b:0'])
with tf.Session() as sess:
print(sess.run([g2_b]))
笔记本:https://gist.github.com/phizaz/21a5454ddc6c2a15c5c0eae91c96cda5
顺便说一句,如果graph_1
或graph_2
包含"变量"这不起作用...我不知道如何初始化那些基础变量到目前为止,任何建议?