在TensorFlow中编写分段函数,然后在TensorFlow中编写

时间:2016-06-23 00:51:20

标签: python tensorflow

如何编写分段TensorFlow函数,即在其中包含if语句的函数?

当前代码

import tensorflow as tf

my_fn = lambda x : x ** 2 if x > 0 else x + 5
with tf.Session() as sess:
  x = tf.Variable(tf.random_normal([100, 1]))
  output = tf.map_fn(my_fn, x)

错误:

TypeError:不允许使用tf.Tensor作为Python bool。使用if t is not None:代替if t:来测试是否定义了张量,并使用逻辑TensorFlow操作来测试张量的值。

4 个答案:

答案 0 :(得分:7)

tf.select不再像这个帖子所表明的那样工作了 https://github.com/tensorflow/tensorflow/issues/8647

对我有用的是tf.where

condition = tf.greater(x, 0)
res = tf.where(condition, tf.square(x), x + 5)

答案 1 :(得分:5)

你应该看看tf.where

对于您的示例,您可以执行以下操作:

condition = tf.greater(x, 0)
res = tf.where(condition, tf.square(x), x + 5)

编辑:从tf.select移至tf.where

答案 2 :(得分:1)

此处的问题是my_fn无法检查条件x>0,因为xtf.Tensor,这意味着它只会被填充值如果启动了tensorflow会话,并且您要求运行包含x的图形的一部分。要在图形本身中包含if-then逻辑,您必须使用tensorflow提供的操作,例如: tf.select

答案 3 :(得分:1)

使用tf.cond

实现此目标的最简单方法
res = tf.cond(tf.greater(x, 0), lambda: tf.square(x), lambda: x + 5)