使用tf.where的TensorFlow渐变在不应该返回NaN时返回NaN

时间:2018-05-05 08:24:15

标签: python tensorflow gradient

以下是可重现的代码。如果你运行它,你会发现在第一次运行中,结果是nan,而第二种情况给出正确的梯度值0.5。但是根据指定的tf.where和条件,它们应该返回相同的值。我也根本不理解为什么tf.where函数渐变为1或-1的纳米,这对我来说似乎是完全精细的输入值。

tf.reset_default_graph()
x = tf.get_variable('x', shape=[1])
condition = tf.less(x, 0.0)
output = tf.where(condition, -tf.log(-x + 1), tf.log(x + 1))
deriv = tf.gradients(output, x)
with tf.Session() as sess:
    print(sess.run(deriv, {x:np.array([-1])}))

logg = -tf.log(-x+1)
derivv = tf.gradients(logg, x)
with tf.Session() as sess:
    print(sess.run(derivv, {x:np.array([-1])}))

感谢您的评论!

1 个答案:

答案 0 :(得分:1)

正如@mikkola提供的github issue中所解释的那样,问题源于tf.where的内部实施。基本上,计算两个备选方案(及其梯度),并且通过乘以条件仅选择正确的部分。唉,如果选择的部分的渐变为infnan,即使乘以0,您也会得到nan,最终传播到结果

由于该问题已于2016年5月提交(该问题的张量流程为v0.7!)并且此后未进行修补,因此可以放心地认为这不会很快成为现实,并开始寻找解决方法。

修复它的最简单方法是修改语句,使它们始终有效且可区分,即使对于不打算选择的值也是如此。

一般技术是将输入值剪切到其有效域内。因此,在您的情况下,您可以使用

cond = tf.less(x, 0.0)
output = tf.where(cond,
  -tf.log(-tf.where(cond, x, 0) + 1),
  tf.log(tf.where(cond, 0, x) + 1))

在您的特定情况下,只使用

会更简单
output = tf.sign(x) * tf.log(tf.abs(x) + 1)