在函数内部构建张量流图

时间:2017-06-07 16:42:47

标签: python tensorflow structure

我正在学习Tensorflow并尝试正确构建我的代码。我(或多或少)知道如何构建裸图或类方法的图形,但我试图弄清楚如何最好地构造代码。我尝试过这个简单的例子:

def build_graph():                
     g = tf.Graph()     
     with g.as_default():                       
         a = tf.placeholder(tf.int8)
         b = tf.add(a, tf.constant(1, dtype=tf.int8))
     return g   

graph = build_graph()
with tf.Session(graph=graph) as sess:
     feed = {a: 3}      
     print(sess.run(b, feed_dict=feed))

应该只打印出来4.但是,当我这样做时,我收到错误:

Cannot interpret feed_dict key as Tensor: Tensor 
Tensor("Placeholder:0", dtype=int8) is not an element of this graph.

我很确定这是因为函数build_graph中的占位符是私有的,但with tf.Session(graph=graph)不应该处理它吗?在这种情况下使用feed dict有更好的方法吗?

2 个答案:

答案 0 :(得分:14)

有几种选择。

选项1 :只传递张量的名称而不是张量本身。

with tf.Session(graph=graph) as sess:
    feed = {"Placeholder:0": 3}      
    print(sess.run("Add:0", feed_dict=feed))

在这种情况下,最好给节点提供有意义的名称,而不是使用上面的默认名称:

def build_graph():
     g = tf.Graph()
     with g.as_default():
         a = tf.placeholder(tf.int8, name="a")
         b = tf.add(a, tf.constant(1, dtype=tf.int8), name="b")
     return g

graph = build_graph()
with tf.Session(graph=graph) as sess:
     feed = {"a:0": 3}
     print(sess.run("b:0", feed_dict=feed))

回想一下,名为"foo"的操作的输出是名为"foo:0""foo:1"的张量,依此类推。大多数操作只有一个输出。

选项2 :让您的build_graph()函数返回所有重要节点。

def build_graph():
     g = tf.Graph()
     with g.as_default():
         a = tf.placeholder(tf.int8)
         b = tf.add(a, tf.constant(1, dtype=tf.int8))
     return g, a, b

graph, a, b = build_graph()
with tf.Session(graph=graph) as sess:
     feed = {a: 3}
     print(sess.run(b, feed_dict=feed))

选项3 :将重要节点添加到集合

def build_graph():
     g = tf.Graph()
     with g.as_default():
         a = tf.placeholder(tf.int8)
         b = tf.add(a, tf.constant(1, dtype=tf.int8))
     for node in (a, b):
         g.add_to_collection("important_stuff", node)
     return g

graph = build_graph()
a, b = graph.get_collection("important_stuff")
with tf.Session(graph=graph) as sess:
     feed = {a: 3}
     print(sess.run(b, feed_dict=feed))

选项4 :根据@pohe的建议,您可以使用get_tensor_by_name()

def build_graph():
     g = tf.Graph()
     with g.as_default():
         a = tf.placeholder(tf.int8, name="a")
         b = tf.add(a, tf.constant(1, dtype=tf.int8), name="b")
     return g

graph = build_graph()
a, b = [graph.get_tensor_by_name(name) for name in ("a:0", "b:0")]
with tf.Session(graph=graph) as sess:
     feed = {a: 3}
     print(sess.run(b, feed_dict=feed))

我个人经常使用选项2,它非常简单并且不需要玩名字。当图表很大并且会长时间存在时我使用选项3,因为集合与模型一起保存,并且它是记录真正重要内容的快速方法。我不是真的使用选项1,因为我更喜欢实际引用对象(不知道为什么)。当您使用由其他人构建的图表时,选项4非常有用,并且他们没有直接引用张量。

希望这有帮助!

答案 1 :(得分:1)

我也在寻找更好的方式,所以我的回答可能不是最好的。不过,如果您给ab一个名称,例如

a = tf.placeholder(tf.int8, name='a')
b = tf.add(a, tf.constant(1, dtype=tf.int8), name='b')

然后你可以做

graph = build_graph()

a = graph.get_tensor_by_name('a:0')
b = graph.get_tensor_by_name('b:0')

with tf.Session(graph=graph) as sess:
    feed = {a: 3}      
    print(sess.run(b, feed_dict=feed))

P.S。命名ab不是必需的。它以后更容易引用。此外,如果您找到了更好的解决方案,请分享。