张量流条件:检查张量内的值是否为零或更大

时间:2018-02-21 15:07:34

标签: python tensorflow

如果我有以下张量:

pmi=tf.constant([[1.5,0.0,0.0],[0.0,0.0,2.9],[1.001,5,1]])

我希望有一个相应的张量Fpmi(或缩放器),这样当PMI张量内的元素大于0时,Fpmi中的元素应为1,而当pmi中的元素为0时,Fpmi中的元素= 0.0005。

我很感激任何建议。

2 个答案:

答案 0 :(得分:1)

使用tf.where,您可以有条件地从两个常量张量中返回元素:

a = tf.constant(1, shape=pmi.shape, dtype=tf.float32)
b = tf.constant(0.0005, shape=pmi.shape, dtype=tf.float32)

tf.where(tf.greater(pmi, 0), a, b).eval()

#array([[  1.00000000e+00,   5.00000024e-04,   5.00000024e-04],
#       [  5.00000024e-04,   5.00000024e-04,   1.00000000e+00],
#       [  1.00000000e+00,   1.00000000e+00,   1.00000000e+00]], dtype=float32)

答案 1 :(得分:0)

如果不大于阈值就想保留原来的值,把tf.where的第二个参数替换成原来的数据。

data = tf.random.uniform(shape=(2, 2))

threshold = 0.4
fill_value_if_bigger = tf.constant(0.5, shape=data.shape, dtype=tf.float32)
replaced = tf.where(tf.greater(data, threshold), fill_value_if_bigger, data).numpy()