我正在基于急切的执行模式在Keras中实现自定义损失功能。问题在于,这种损失需要停止某些特定变量的梯度计算。在图形中执行时,我们可以使用此操作tf.stop_gradient
。实际上,对于Keras,GradientTape是自动创建的,我认为无法访问Tape变量来进行重置。
def virtual_adversarial_loss(X, DAE_encoder):
r_vadv = generate_virtual_adversarial_perturbation(X, DAE_encoder)
tape.reset()
tape.watch(X)
tape.watch(r_vadv)
p = DAE_encoder(X)
p = tape.watch(p)
q = DAE_encoder(x+r_vadv)
loss = kl(p, q)
return tf.identity(loss, name="vat_loss")
def virtual_adversarial_loss(X, DAE_encoder):
r_vadv = generate_virtual_adversarial_perturbation(X, DAE_encoder)
tf.stop_gradient(X)
tf.stop_gradient(r_vadv)
p = DAE_encoder(X)
p = tf.stop_gradient(p)
q = DAE_encoder(x+r_vadv)
loss = kl(p, q)
return tf.identity(loss, name="vat_loss")
这两个功能都不起作用。第一个问题是我不能使用内部Keras包装函数创建的tape变量。第二个问题是我不能将tf.stop_gradient
用于急切的执行模式。
有没有办法在Keras急切的执行模式下停止梯度计算?