Tensorflow钳位值超出特定范围

时间:2016-08-08 11:55:28

标签: python-3.x tensorflow

我一直在使用tensorflow来实现卷积神经网络, 我要求输出值小于给定值 MAX_VAL

我尝试创建一个填充 MAX_VAL 的矩阵,然后使用 tf.select tf.greater

filled = tf.fill(output.get_shape(),MAX_VAL)
modoutput = tf.select(tf.greater(output, filled), filled, output)

但这不起作用,因为输出的形状不是静态知道的: 它是[?,30]并且tf.fill需要一个明确的形状。

我知道如何实现这个目标?

2 个答案:

答案 0 :(得分:2)

有一种替代解决方案,使用tf.fill()就像您的初始版本一样。不要使用Tensor.get_shape()来获取output的静态形状,而是使用tf.shape()运算符在步骤运行时获得output的动态形状:

output = ...
filled = tf.fill(tf.shape(output), MAX_VAL)
modoutput = tf.select(tf.greater(output, filled), filled, output)

(另请注意,tf.clip_by_value()运算符可能对您有用。)

答案 1 :(得分:0)

我想出了办法。

我没有使用 tf.fill ,而是使用 tf.ones_like

filled = MAX_VAL*tf.ones_like(output)
modoutput = tf.select(tf.greater(output, filled), filled, output)

请提及是否有更快或更好的方法可以做到这一点。