在tf.where()中使用==条件的问题

时间:2018-07-24 00:29:40

标签: tensorflow

我最近开始使用tensorflow并使用tf.where()函数。我注意到,每当我使用“ ==”条件时,它都会引发错误。例如,当我尝试以下操作时:

t = tf.constant([[1, 2, 3], 
                 [4, 5, 6]])

t2 = tf.where(t==2)
t3 = tf.gather_nd(t,t2)

t3_shape= tf.shape(t)[0]

with tf.Session() as sess:
    print(sess.run([t3]))

它引发以下错误:

  

InvalidArgumentError:WhereOp:未处理的输入尺寸:0

有人可以在这里解释什么错误吗? 预先感谢!

1 个答案:

答案 0 :(得分:4)

您需要tf.equal进行逐元素比较:

t2 = tf.where(tf.equal(t, 2))

t = tf.constant([[1, 2, 3],
                 [4, 5, 6]])

t2 = tf.where(tf.equal(t, 2))
t3 = tf.gather_nd(t,t2)   
t3_shape= tf.shape(t)[0]

with tf.Session() as sess:
    print(sess.run([t3]))

# [array([2])]