如何使用tf.where()根据条件替换特定值

时间:2019-03-28 08:58:32

标签: python tensorflow

我想替换条件下的值。
NumPy版本将如下所示

intensity=np.where(
  np.abs(intensity)<1e-4,
  1e-4,
  intensity)

但是TensorFlow在tf.where()中的用法有所不同
当我尝试这个

intensity=tf.where(
  tf.math.abs(intensity)<1e-4,
  1e-4,
  intensity)

我收到此错误

ValueError: Shapes must be equal rank, but are 0 and 4 for 'Select' (op: 'Select') with input shapes: [?,512,512,1], [], [?,512,512,1].

这是否意味着我应该为1e-4使用4维张量?

1 个答案:

答案 0 :(得分:1)

以下代码传递了错误

# Create an array which has small value (1e-4),  
# whose shape is (2,512,512,1)
small_val=np.full((2,512,512,1),1e-4).astype("float32")

# Convert numpy array to tf.constant
small_val=tf.constant(small_val)

# Use tf.where()
intensity=tf.where(
  tf.math.abs(intensity)<1e-4,
  small_val,
  intensity)

# Error doesn't occur
print(intensity.shape)
# (2, 512, 512, 1)