张量中的TensorFlow修剪值

时间:2017-09-25 15:25:54

标签: python tensorflow

如何在TensorFlow张量中执行以下操作?

在矩阵A中:如果A [i,j]> 1然后A [i,j] = 1

(在numpy我会这样做:A[A>1] = 1

1 个答案:

答案 0 :(得分:2)

您可以使用tf.minimum,它按元素进行最小化计算;设置y = 1后,x中的值将被裁剪为1的最大值:

A = tf.constant([-1, 0, 1, 3, 4])

A_clipped = tf.minimum(A, 1)

sess = tf.InteractiveSession()
A_clipped.eval()
# array([-1,  0,  1,  1,  1], dtype=int32)

另一个选项是使用tf.where来设置值:

tf.where(A > 1, tf.constant(1, shape=A.shape), A).eval()
# array([-1,  0,  1,  1,  1], dtype=int32)

如果您需要更新变量A

A = tf.Variable([-1, 0, 1, 3, 4])
​
tf.global_variables_initializer().run()
tf.assign(A, tf.minimum(A, 1)).eval()

A.eval()
# array([-1,  0,  1,  1,  1], dtype=int32)