我正在学习Tensorflow并尝试正确构建我的代码。我(或多或少)知道如何构建裸图或类方法的图形,但我试图弄清楚如何最好地构造代码。我尝试过这个简单的例子:
def build_graph():
g = tf.Graph()
with g.as_default():
a = tf.placeholder(tf.int8)
b = tf.add(a, tf.constant(1, dtype=tf.int8))
return g
graph = build_graph()
with tf.Session(graph=graph) as sess:
feed = {a: 3}
print(sess.run(b, feed_dict=feed))
应该只打印出来4.但是,当我这样做时,我收到错误:
Cannot interpret feed_dict key as Tensor: Tensor
Tensor("Placeholder:0", dtype=int8) is not an element of this graph.
我很确定这是因为函数build_graph
中的占位符是私有的,但with tf.Session(graph=graph)
不应该处理它吗?在这种情况下使用feed dict有更好的方法吗?
答案 0 :(得分:14)
有几种选择。
选项1 :只传递张量的名称而不是张量本身。
with tf.Session(graph=graph) as sess:
feed = {"Placeholder:0": 3}
print(sess.run("Add:0", feed_dict=feed))
在这种情况下,最好给节点提供有意义的名称,而不是使用上面的默认名称:
def build_graph():
g = tf.Graph()
with g.as_default():
a = tf.placeholder(tf.int8, name="a")
b = tf.add(a, tf.constant(1, dtype=tf.int8), name="b")
return g
graph = build_graph()
with tf.Session(graph=graph) as sess:
feed = {"a:0": 3}
print(sess.run("b:0", feed_dict=feed))
回想一下,名为"foo"
的操作的输出是名为"foo:0"
,"foo:1"
的张量,依此类推。大多数操作只有一个输出。
选项2 :让您的build_graph()
函数返回所有重要节点。
def build_graph():
g = tf.Graph()
with g.as_default():
a = tf.placeholder(tf.int8)
b = tf.add(a, tf.constant(1, dtype=tf.int8))
return g, a, b
graph, a, b = build_graph()
with tf.Session(graph=graph) as sess:
feed = {a: 3}
print(sess.run(b, feed_dict=feed))
选项3 :将重要节点添加到集合
def build_graph():
g = tf.Graph()
with g.as_default():
a = tf.placeholder(tf.int8)
b = tf.add(a, tf.constant(1, dtype=tf.int8))
for node in (a, b):
g.add_to_collection("important_stuff", node)
return g
graph = build_graph()
a, b = graph.get_collection("important_stuff")
with tf.Session(graph=graph) as sess:
feed = {a: 3}
print(sess.run(b, feed_dict=feed))
选项4 :根据@pohe的建议,您可以使用get_tensor_by_name()
def build_graph():
g = tf.Graph()
with g.as_default():
a = tf.placeholder(tf.int8, name="a")
b = tf.add(a, tf.constant(1, dtype=tf.int8), name="b")
return g
graph = build_graph()
a, b = [graph.get_tensor_by_name(name) for name in ("a:0", "b:0")]
with tf.Session(graph=graph) as sess:
feed = {a: 3}
print(sess.run(b, feed_dict=feed))
我个人经常使用选项2,它非常简单并且不需要玩名字。当图表很大并且会长时间存在时我使用选项3,因为集合与模型一起保存,并且它是记录真正重要内容的快速方法。我不是真的使用选项1,因为我更喜欢实际引用对象(不知道为什么)。当您使用由其他人构建的图表时,选项4非常有用,并且他们没有直接引用张量。
希望这有帮助!
答案 1 :(得分:1)
我也在寻找更好的方式,所以我的回答可能不是最好的。不过,如果您给a
和b
一个名称,例如
a = tf.placeholder(tf.int8, name='a')
b = tf.add(a, tf.constant(1, dtype=tf.int8), name='b')
然后你可以做
graph = build_graph()
a = graph.get_tensor_by_name('a:0')
b = graph.get_tensor_by_name('b:0')
with tf.Session(graph=graph) as sess:
feed = {a: 3}
print(sess.run(b, feed_dict=feed))
P.S。命名a
和b
不是必需的。它以后更容易引用。此外,如果您找到了更好的解决方案,请分享。