如何编写分段TensorFlow函数,即在其中包含if语句的函数?
当前代码
import tensorflow as tf
my_fn = lambda x : x ** 2 if x > 0 else x + 5
with tf.Session() as sess:
x = tf.Variable(tf.random_normal([100, 1]))
output = tf.map_fn(my_fn, x)
错误:
TypeError:不允许使用tf.Tensor
作为Python bool
。使用if t is not None:
代替if t:
来测试是否定义了张量,并使用逻辑TensorFlow操作来测试张量的值。
答案 0 :(得分:7)
tf.select
不再像这个帖子所表明的那样工作了
https://github.com/tensorflow/tensorflow/issues/8647
对我有用的是tf.where
condition = tf.greater(x, 0)
res = tf.where(condition, tf.square(x), x + 5)
答案 1 :(得分:5)
你应该看看tf.where
。
对于您的示例,您可以执行以下操作:
condition = tf.greater(x, 0)
res = tf.where(condition, tf.square(x), x + 5)
编辑:从tf.select
移至tf.where
答案 2 :(得分:1)
此处的问题是my_fn
无法检查条件x>0
,因为x
是tf.Tensor
,这意味着它只会被填充值如果启动了tensorflow会话,并且您要求运行包含x
的图形的一部分。要在图形本身中包含if-then逻辑,您必须使用tensorflow提供的操作,例如: tf.select
答案 3 :(得分:1)
使用tf.cond
res = tf.cond(tf.greater(x, 0), lambda: tf.square(x), lambda: x + 5)