我遇到错误,代码如下。我想检查函数floor的梯度,发生错误:
import numpy as np
import tensorflow as tf
def floor(x):
return tf.floor(x)
# code
w1 = tf.Variable([[1.5, 0.5, -0.5, -1.5]])
res = floor(w1)
grads = tf.gradients(res, [w1])
with tf.Session() as sess:
tf.global_variables_initializer().run()
print(sess.run(grads))
答案 0 :(得分:1)
TensorFlow中floor()
函数的渐变定义为返回None
,因为⌊x⌋的渐变在任何地方都是0(除了整数),所以这允许后端代码处理它因为没有联系。
有关相关的git问题,请参阅here。