TensorFlow:tf.where()出错

时间:2017-03-31 19:42:48

标签: python tensorflow

我不确定为什么tf.where()不能按计划运行。我想使用a的值yt小于5,否则使用b

tf.InteractiveSession()
yt = tf.constant([10,1,10])
a = tf.constant([1,2,3])
b = tf.constant([3,4,5])
tf.where(tf.less(yt,[5]), a, b).eval()

给出错误

where() takes at most 2 arguments (3 given)

你能告诉我为什么会收到这个错误吗?还有其他办法吗?

1 个答案:

答案 0 :(得分:4)

{TensorFlow 0.10(当它是took two arguments and returned two outputs)和TensorFlow 0.12+(现在是takes three tensor arguments and returns a single output)之间的tf.where()语法被更改了,取代了前tf.select()

作为Himaprasoon suggests,升级到最新版本的TensorFlow可以解决您的问题。