我有三个数组,X
,Y
和Z
。如果res
中的相应元素为真,我想放入X
和Z
的元素;否则,我会从Y
中添加一个元素。
我实现了这样:
X = tf.constant([[1, 2], [3, 4]])
Y = tf.constant([[5, 6], [7, 8]])
Z = tf.constant([[True, False], [False, True]], tf.bool)
res = tf.where(Z, X, Y)
print(res.eval())
但是,我收到此错误:
TypeError: where() takes from 1 to 2 positional arguments but 3 were given
我查看了来自here的tf.where
的定义,我的用法似乎很好。
知道可能是什么问题吗?
答案 0 :(得分:1)
我怀疑您使用的是旧版TensorFlow:
e.g。在r0.10 tf.where
中,过去只有2个参数。
tf.where(input, name=None)