在TensorFlow中,我有一个tf.while_loop
,其中body
参数定义为以下函数:
def loop_body(step_num, x):
if step_num == 0:
x += 1
else:
x += 2
step_num = tf.add(step_num, 1)
return step_num, x
问题在于,即使step_num == 0
的初始值为True
,行step_num
也不是0
。我假设这是因为step_num
不是整数,而是实际上在循环tf.constant
之外定义的step_num = tf.constant(0)
。因此,我正在将tf.constant
与Python整数(即False
)进行比较。
该比较中应该使用什么?
答案 0 :(得分:3)
第一种方法:使用tf.cond
:
def loop_body(step_num, x):
x = tf.cond(tf.equal(step_num,0),lambda :x+1,lambda :x+2)
step_num = tf.add(step_num, 1)
return step_num, x
第二种方法:使用autograph
:
from tensorflow.contrib import autograph as ag
ag.to_graph(loop_body2)(step_num, x)
一个例子:
import tensorflow as tf
from tensorflow.contrib import autograph as ag
def loop_body(step_num, x):
x = tf.cond(tf.equal(step_num,0),lambda :x+1,lambda :x+2)
step_num = tf.add(step_num, 1)
return step_num, x
def loop_body2(step_num, x):
if step_num == 0:
x += 1
else:
x += 2
step_num = tf.add(step_num, 1)
return step_num, x
step_num = tf.constant(0)
x = tf.constant(2)
result1 = loop_body(step_num, x)
result2 = ag.to_graph(loop_body2)(step_num, x)
with tf.Session() as sess:
print(sess.run(result1))
print(sess.run(result2))
#print
(1, 3)
(1, 3)