在Tensorflow 2.0中替换损失函数的梯度计算

时间:2019-11-20 11:49:45

标签: tensorflow tensorflow2.0 loss-function

我想用tensorflow 2.0中的损失函数替换为梯度函数。

例如,我有一个损失函数,看起来像:

def loss_function(prediction):
    # do some standard tensorflow things here
    return loss

然后我使用tf.GradientTape方法(即

)应用渐变
with tf.GradientTape() as tape:

     prediction = model(input)
     loss = loss_function(prediction)

gradients = tf.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))

我的问题是,我想更改梯度计算并自己针对当前自动计算的loss_function()进行显式计算。

我猜想这与@tf.custom_gradient装饰器有关,但是不确定如何使它工作以弥补损失。

我正在使用与顺序/功能api相对应的自定义训练循环。

1 个答案:

答案 0 :(得分:0)

答案实际上很简单。首先定义一个@tf.custom_gradient()函数,该函数定义渐变并通过它传递损耗,即

@tf.custom_gradient
def custom_grad(x):
    def grad(dy, **kwargs):
        # do something with dy here
        return dy 
    return tf.identity(x), grad

并通过以下方法将原始函数中的损失变量传递给我们:

def loss_function(prediction):
    # calculate the loss
    return custom_grad(loss)