e.g。给出m x n
张量,我试图找到大于阈值的元素。
看来这可以用tf.greater
来完成,但似乎我需要构建一个m x n
张量的阈值?
有什么好办法吗?
答案 0 :(得分:1)
看起来你没有长时间搜索:
import tensorflow as tf
x= tf.constant([[0, 1, 2], [3, 4, 5]], dtype=tf.float32)
out= tf.greater(x, 2.5)
with tf.Session() as sess:
print(sess.run(out))
给出:
[[False False False] [True True True]]
答案 1 :(得分:1)
这是一种计数大于阈值的元素数量的方法:
x = tf.constant([[1,2,3,4],[2,3,4,5],[3,4,5,6]])
threshold = 4
elements_gt = tf.math.greater(x,threshold)
num_elements_gt = tf.math.reduce_sum(tf.cast(elements_gt, tf.int32))
print(num_elements_gt)
计算tf.greater
时,可以使用tf.greater_equal
,tf.less
,tf.less_equal
,elements_gt
作为过滤器。