将“ tf.constant”与整数进行比较

时间:2018-12-18 13:33:35

标签: python tensorflow

在TensorFlow中,我有一个tf.while_loop,其中body参数定义为以下函数:

def loop_body(step_num, x):
    if step_num == 0:
        x += 1
    else:
        x += 2
    step_num = tf.add(step_num, 1)
    return step_num, x

问题在于,即使step_num == 0的初始值为True,行step_num也不是0。我假设这是因为step_num不是整数,而是实际上在循环tf.constant之外定义的step_num = tf.constant(0)。因此,我正在将tf.constant与Python整数(即False)进行比较。

该比较中应该使用什么?

1 个答案:

答案 0 :(得分:3)

第一种方法:使用tf.cond

def loop_body(step_num, x):
    x = tf.cond(tf.equal(step_num,0),lambda :x+1,lambda :x+2)
    step_num = tf.add(step_num, 1)
    return step_num, x

第二种方法:使用autograph

from tensorflow.contrib import autograph as ag
ag.to_graph(loop_body2)(step_num, x)

一个例子:

import tensorflow as tf
from tensorflow.contrib import autograph as ag

def loop_body(step_num, x):
    x = tf.cond(tf.equal(step_num,0),lambda :x+1,lambda :x+2)
    step_num = tf.add(step_num, 1)
    return step_num, x

def loop_body2(step_num, x):
    if step_num == 0:
        x += 1
    else:
        x += 2
    step_num = tf.add(step_num, 1)
    return step_num, x

step_num = tf.constant(0)
x = tf.constant(2)
result1 = loop_body(step_num, x)
result2 = ag.to_graph(loop_body2)(step_num, x)

with tf.Session() as sess:
    print(sess.run(result1))
    print(sess.run(result2))

#print 
(1, 3)
(1, 3)