Tf.where无法评估

时间:2018-09-02 09:04:57

标签: numpy tensorflow

sess = tf.InteractiveSession()
t = tf.expand_dims(tf.constant(list(range(9))), axis=1)
tf.where(t == 5).eval()

InvalidArgumentError (see above for traceback): WhereOp : Unhandled input dimensions: 0
     [[Node: Where_16 = Where[T=DT_BOOL, _device="/job:localhost/replica:0/task:0/device:CPU:0"](Where_16/condition)]]

这是怎么回事? Numpy中带有np.where的相应代码。

1 个答案:

答案 0 :(得分:0)

在您的示例中,您正在评估tf.where(False),因为==运算符不会因张量而过载。 (更多信息,例如:TensorFlow operator overloading

尝试:

sess = tf.InteractiveSession()
t = tf.expand_dims(tf.constant(list(range(9))), axis=1)
tf.where(tf.equal(t, 5)).eval()