答案 0 :(得分:0)
通常,这可以通过tf.custom_gradient
完成。这样,您就可以编写一个正向函数以及一个自定义梯度函数,该函数取决于“传入”梯度(来自模型下方的各层)以及正向函数。
有了它,您可以创建一个像这样的函数:
@tf.custom_gradient
def sign_with_grad(x):
output = tf.sign(x)
def grad_fn(dy):
check_range = tf.where(tf.less_equal(tf.abs(x), 1.), 1., 0.)
return dy*check_range
return output, grad_fn
在这里,我们编写了一个自定义渐变,该渐变简单地通过了上一层的渐变-只是将输入超出[-1,1]范围的所有元素清零。我们返回正向函数(正负号)的结果以及梯度函数本身。装饰者负责处理其余部分。
请注意,我没有检查代码是否可以运行-让我知道是否无法运行!特别是,less_equal
和或where
检查可能需要显式广播-例如使用tf.ones_like(x)
代替1.
(等于0)。