sess = tf.InteractiveSession()
t = tf.expand_dims(tf.constant(list(range(9))), axis=1)
tf.where(t == 5).eval()
InvalidArgumentError (see above for traceback): WhereOp : Unhandled input dimensions: 0
[[Node: Where_16 = Where[T=DT_BOOL, _device="/job:localhost/replica:0/task:0/device:CPU:0"](Where_16/condition)]]
这是怎么回事? Numpy中带有np.where的相应代码。
答案 0 :(得分:0)
在您的示例中,您正在评估tf.where(False)
,因为==
运算符不会因张量而过载。 (更多信息,例如:TensorFlow operator overloading)
尝试:
sess = tf.InteractiveSession()
t = tf.expand_dims(tf.constant(list(range(9))), axis=1)
tf.where(tf.equal(t, 5)).eval()