我不确定为什么tf.where()不能按计划运行。我想使用a
的值yt
小于5,否则使用b
。
tf.InteractiveSession()
yt = tf.constant([10,1,10])
a = tf.constant([1,2,3])
b = tf.constant([3,4,5])
tf.where(tf.less(yt,[5]), a, b).eval()
给出错误
where() takes at most 2 arguments (3 given)
你能告诉我为什么会收到这个错误吗?还有其他办法吗?
答案 0 :(得分:4)
{TensorFlow 0.10(当它是took two arguments and returned two outputs)和TensorFlow 0.12+(现在是takes three tensor arguments and returns a single output)之间的tf.where()
语法被更改了,取代了前tf.select()
。
作为Himaprasoon suggests,升级到最新版本的TensorFlow可以解决您的问题。