我想用tensorflow 2.0
中的损失函数替换为梯度函数。
例如,我有一个损失函数,看起来像:
def loss_function(prediction):
# do some standard tensorflow things here
return loss
然后我使用tf.GradientTape
方法(即
with tf.GradientTape() as tape:
prediction = model(input)
loss = loss_function(prediction)
gradients = tf.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
我的问题是,我想更改梯度计算并自己针对当前自动计算的loss_function()
进行显式计算。
我猜想这与@tf.custom_gradient
装饰器有关,但是不确定如何使它工作以弥补损失。
我正在使用与顺序/功能api相对应的自定义训练循环。
答案 0 :(得分:0)
答案实际上很简单。首先定义一个@tf.custom_gradient()
函数,该函数定义渐变并通过它传递损耗,即
@tf.custom_gradient
def custom_grad(x):
def grad(dy, **kwargs):
# do something with dy here
return dy
return tf.identity(x), grad
并通过以下方法将原始函数中的损失变量传递给我们:
def loss_function(prediction):
# calculate the loss
return custom_grad(loss)