停止Keras渴望执行模式的梯度计算

时间:2018-09-19 09:19:18

标签: python tensorflow keras

我正在基于急切的执行模式在Keras中实现自定义损失功能。问题在于,这种损失需要停止某些特定变量的梯度计算。在图形中执行时,我们可以使用此操作tf.stop_gradient。实际上,对于Keras,GradientTape是自动创建的,我认为无法访问Tape变量来进行重置。

def virtual_adversarial_loss(X, DAE_encoder):
    r_vadv = generate_virtual_adversarial_perturbation(X, DAE_encoder)
    tape.reset()
    tape.watch(X)
    tape.watch(r_vadv)
    p =  DAE_encoder(X)
    p = tape.watch(p)
    q = DAE_encoder(x+r_vadv)
    loss = kl(p, q)
    return tf.identity(loss, name="vat_loss")

def virtual_adversarial_loss(X, DAE_encoder):
    r_vadv = generate_virtual_adversarial_perturbation(X, DAE_encoder)
    tf.stop_gradient(X)
    tf.stop_gradient(r_vadv)
    p =  DAE_encoder(X)
    p = tf.stop_gradient(p)
    q = DAE_encoder(x+r_vadv)
    loss = kl(p, q)
    return tf.identity(loss, name="vat_loss")

这两个功能都不起作用。第一个问题是我不能使用内部Keras包装函数创建的tape变量。第二个问题是我不能将tf.stop_gradient用于急切的执行模式。

有没有办法在Keras急切的执行模式下停止梯度计算?

0 个答案:

没有答案