where()取1到2个位置参数,但给出了3个

时间:2017-06-06 18:27:52

标签: python tensorflow

我有三个数组,XYZ。如果res中的相应元素为真,我想放入XZ的元素;否则,我会从Y中添加一个元素。

我实现了这样:

X = tf.constant([[1, 2], [3, 4]])
Y = tf.constant([[5, 6], [7, 8]])
Z = tf.constant([[True, False], [False, True]], tf.bool)
res = tf.where(Z, X, Y)
print(res.eval())

但是,我收到此错误:

TypeError: where() takes from 1 to 2 positional arguments but 3 were given

我查看了来自heretf.where的定义,我的用法似乎很好。

知道可能是什么问题吗?

1 个答案:

答案 0 :(得分:1)

我怀疑您使用的是旧版TensorFlow:

e.g。在r0.10 tf.where中,过去只有2个参数。

tf.where(input, name=None)

https://www.tensorflow.org/versions/r0.10/api_docs/python/math_ops/sequence_comparison_and_indexing#where