我正在尝试将布尔值feed_dict传递给函数
def sum(a, b, flag = True, msg1= "Sum", msg2= "Multiply "):
if (flag is True):
print(msg1)
vtotal = tf.add(a,b)
else:
print(msg2)
vtotal = tf.multiply(a,b)
return vtotal
当我将函数称为sum(a,b)时,flag = True的默认值用于处理
但是当我将函数调用为
时sum(a, b, flag):
我像feed一样从feed_dict提供了标志的值
output = sess.run(total,feed_dict = {a: a_arr, b: b_arr, flag: True})
它不将值视为True,而是执行函数的else部分
完整代码如下:请帮助为什么会发生这种情况。
def initialize_placeholders():
a = tf.placeholder(tf.float32,[3,None],name="a")
b = tf.placeholder(tf.float32,[3,None],name ="b")
flag = tf.placeholder(tf.bool, name="flag")
return a, b, flag
def sum(a, b, flag = True, msg1= "Sum", msg2= "Multiply "):
if (flag is True):
print(msg1)
vtotal = tf.add(a,b)
else:
print(msg2)
vtotal = tf.multiply(a,b)
return vtotal
def model(a_arr,b_arr):
#print(a_arr)
#print(b_arr)
tf.reset_default_graph()
a, b ,flag= initialize_placeholders()
total = sum(a,b,flag)
init = tf.global_variables_initializer()
print(flag)
with tf.Session() as sess:
sess.run(init)
output = sess.run(total,feed_dict = {a: a_arr, b: b_arr, flag: True})
print(flag)
unv = sess.run(tf.report_uninitialized_variables())
sess.close()
return output, unv
a_arr = np.arange(6)
a_arr = a_arr.reshape(3,2)
b_arr = np.array([2,4,6,8,10,12])
b_arr = b_arr.reshape(3,2)
output , unv = model(a_arr,b_arr)
print(output)
print(unv)
答案 0 :(得分:1)
您不能在常规条件Python语句中使用TensorFlow值(除非您使用的是AutoGraph之类的东西)。您可以使用tf.cond
来完成自己想要的事情:
def sum(a, b, flag=True):
flag = tf.convert_to_tensor(flag)
return tf.cond(flag, lambda: tf.add(a, b), lambda: tf.multiply(a, b))
为了预先保存flag
的值来保存tf.cond
操作,您还可以使其更加复杂。例如,您可能会遇到这样的情况:
def sum(a, b, flag = True, msg1= "Sum", msg2= "Multiply "):
true_fn = lambda: tf.add(a, b)
false_fn = lambda: tf.multiply(a, b)
if flag is True:
return true_fn()
elif flag is False:
return false_fn()
else: # Use TensorFlow conditional
flag = tf.convert_to_tensor(flag)
return tf.cond(flag, true_fn, false_fn)
我删除了print
指令,因为它们不能直接在TensorFlow条件语句中使用,但是如果要在执行图形时查看打印的消息,您仍然可以进行tf.print
操作。