连接两个图表

时间:2016-05-12 20:57:14

标签: tensorflow

假设我有两个不同的图表: 第一个包含x和y: x = tf.placeholder(tf.float32,shape =(1)), y = 2 * x, 第二个包含a和b: a = tf.placeholder(tf.float32,shape =(1)), b = 2 * x。

现在,我想通过在y和a之间添加一些“身份链接”来连接这两个图。换句话说,我想告诉第二个图从第一个图(y)中的某个节点获取其输入(a)。在没有代码重新创建第二个图形的情况下,它很方便,您只需从某个地方对其进行反序列化。一种方法是使用Session.run计算第一个图的输出,然后将其提供给计算第二个图的输出的Session.run调用,但必须有一些干净的方法。

谢谢!

1 个答案:

答案 0 :(得分:1)

如果我理解正确,这对你有用吗?

它利用tf.import_graph_def作业

我们有x,然后输入第一个图表以获取y = 2 *x, 然后我们将y提供给第二个图表以获取b = 2 * y,对于x = 1.0,以下代码将生成4.0

import tensorflow as tf
FLOAT = tf.float32
tf.reset_default_graph()

def graph_1():
    g = tf.Graph()
    with g.as_default():
        x = tf.placeholder(FLOAT, [], name='x')
        y = tf.multiply(2.0, x, name='y')
    return g

def graph_2():
    g = tf.Graph()
    with g.as_default():
        a = tf.placeholder(FLOAT, [], name='a')
        b = tf.multiply(2.0, a, name='b')
    return g

# x = 1.0
x = tf.constant(1.0, FLOAT, [])
# feed x to graph_1 -> y = 2.0
g1 = graph_1()
[g1_y] = tf.import_graph_def(g1.as_graph_def(), input_map={'x': x}, return_elements=['y:0'])
# feed y to graph_2 -> b = 4.0
g2 = graph_2()
[g2_b] = tf.import_graph_def(g2.as_graph_def(), input_map={'a': g1_y}, return_elements=['b:0'])

with tf.Session() as sess:
    print(sess.run([g2_b]))

笔记本:https://gist.github.com/phizaz/21a5454ddc6c2a15c5c0eae91c96cda5

顺便说一句,如果graph_1graph_2包含"变量"这不起作用...我不知道如何初始化那些基础变量到目前为止,任何建议?