从feed_dict将bool传递给函数不起作用

时间:2019-07-11 17:34:21

标签: python tensorflow

我正在尝试将布尔值feed_dict传递给函数

def sum(a, b, flag = True, msg1= "Sum", msg2= "Multiply "):

    if (flag is True):
        print(msg1)
        vtotal = tf.add(a,b)
    else:
        print(msg2)
        vtotal = tf.multiply(a,b)

    return vtotal

当我将函数称为sum(a,b)时,flag = True的默认值用于处理

但是当我将函数调用为

sum(a, b, flag):

我像feed一样从feed_dict提供了标志的值

output = sess.run(total,feed_dict = {a: a_arr, b: b_arr, flag: True})

它不将值视为True,而是执行函数的else部分

完整代码如下:请帮助为什么会发生这种情况。

def initialize_placeholders():
    a = tf.placeholder(tf.float32,[3,None],name="a")
    b = tf.placeholder(tf.float32,[3,None],name ="b")
    flag = tf.placeholder(tf.bool, name="flag")

    return a, b, flag

def sum(a, b, flag = True, msg1= "Sum", msg2= "Multiply "):

    if (flag is True):
        print(msg1)
        vtotal = tf.add(a,b)
    else:
        print(msg2)
        vtotal = tf.multiply(a,b)

    return vtotal

def model(a_arr,b_arr):
    #print(a_arr)
    #print(b_arr)
    tf.reset_default_graph()
    a, b ,flag= initialize_placeholders()
    total = sum(a,b,flag)

    init = tf.global_variables_initializer()
    print(flag)

    with tf.Session() as sess:
        sess.run(init)
        output = sess.run(total,feed_dict = {a: a_arr, b: b_arr, flag: True})
        print(flag)
        unv = sess.run(tf.report_uninitialized_variables())
        sess.close()
    return output, unv

a_arr = np.arange(6)
a_arr = a_arr.reshape(3,2)
b_arr = np.array([2,4,6,8,10,12])
b_arr = b_arr.reshape(3,2)
output , unv = model(a_arr,b_arr)
print(output)
print(unv)

1 个答案:

答案 0 :(得分:1)

您不能在常规条件Python语句中使用TensorFlow值(除非您使用的是AutoGraph之类的东西)。您可以使用tf.cond来完成自己想要的事情:

def sum(a, b, flag=True):
    flag = tf.convert_to_tensor(flag)
    return tf.cond(flag, lambda: tf.add(a, b), lambda: tf.multiply(a, b))

为了预先保存flag的值来保存tf.cond操作,您还可以使其更加复杂。例如,您可能会遇到这样的情况:

def sum(a, b, flag = True, msg1= "Sum", msg2= "Multiply "):
    true_fn = lambda: tf.add(a, b)
    false_fn = lambda: tf.multiply(a, b)
    if flag is True:
        return true_fn()
    elif flag is False:
        return false_fn()
    else:  # Use TensorFlow conditional
        flag = tf.convert_to_tensor(flag)
        return tf.cond(flag, true_fn, false_fn)

我删除了print指令,因为它们不能直接在TensorFlow条件语句中使用,但是如果要在执行图形时查看打印的消息,您仍然可以进行tf.print操作。