我最近开始使用tensorflow并使用tf.where()函数。我注意到,每当我使用“ ==”条件时,它都会引发错误。例如,当我尝试以下操作时:
t = tf.constant([[1, 2, 3],
[4, 5, 6]])
t2 = tf.where(t==2)
t3 = tf.gather_nd(t,t2)
t3_shape= tf.shape(t)[0]
with tf.Session() as sess:
print(sess.run([t3]))
它引发以下错误:
InvalidArgumentError:WhereOp:未处理的输入尺寸:0
有人可以在这里解释什么错误吗? 预先感谢!
答案 0 :(得分:4)
您需要tf.equal
进行逐元素比较:
t2 = tf.where(tf.equal(t, 2))
t = tf.constant([[1, 2, 3],
[4, 5, 6]])
t2 = tf.where(tf.equal(t, 2))
t3 = tf.gather_nd(t,t2)
t3_shape= tf.shape(t)[0]
with tf.Session() as sess:
print(sess.run([t3]))
# [array([2])]