label = tf.constant([0,1,2,3,4,4,5,5])
我有一个张量,例如,高于一个。 我想过滤元素为4的张量。输出张量应为[4,4]。 怎么实现呢?感谢。
答案 0 :(得分:0)
只需使用tf.where
来获取条件为真的索引,并使用tf.gather
来收集指定的值
import tensorflow as tf
label = tf.constant([0,1,2,3,4,4,5,5])
filtered = tf.gather(label, tf.where(tf.equal(label, 4)))
sess = tf.Session()
print(sess.run(filtered))
[[4] [4]]