我试图在Tensorflow中执行以下操作 -
import tensorflow as tf
graph = tf.Graph()
with graph.as_default():
i = tf.Variable(0)
sol = tf.Variable(0)
def cond(i, sol):
return tf.less(i, 2)
def body(i, sol):
i = tf.add(i, 1)
sol = tf.add(sol, 1)
tf.while_loop(cond, body, [i, sol])
with tf.Session(graph=graph) as session:
tf.global_variables_initializer().run()
result = session.run(sol, feed_dict={})
print result
我无法理解这两种结构是什么'在错误消息中。我想最终制作一个&tff.while_loop'与条件'基于tf.Placeholder的值(' i'在上面的代码中)。
答案 0 :(得分:3)
您应该将return
添加到body
功能:
def body(i, sol):
i = tf.add(i, 1)
sol = tf.add(sol, 1)
retrun [i, sol]
但我认为您还应该将代码更改为
graph = tf.Graph()
with graph.as_default():
i = tf.Variable(0)
sol = tf.Variable(0)
def cond(i, sol):
return tf.less(i, 2)
def body(i, sol):
i = tf.add(i, 1)
sol = tf.add(sol, 1)
return [i, sol]
result = tf.while_loop(cond, body, [i, sol])
with tf.Session(graph=graph) as session:
tf.global_variables_initializer().run()
result = session.run(result, feed_dict={})
print(result[1])
因为tf.while
只是图表中的节点,您应该运行该节点,否则您将无法获得任何结果。