我已经使用tf.scan使用TensorFlow实现了一些递归模型,目前我正在尝试获取输出损耗w.r.t的梯度。出于研究目的在每个时间步上都注明。
可以将关键过程简化为计算梯度w.r.t. while_loop主体中的中间值:
import tensorflow as tf
tf.InteractiveSession()
sequence = tf.constant([2, 3, 4, 5])
def body(t, previous_s):
intermediate_s = previous_s * sequence[t]
return t + 1, intermediate_s
t0, s0 = 0, 1
_, s4 = tf.while_loop(lambda t, _: t < sequence.shape[0], body, (t0, s0))
print('s4 =', s4.eval())
输出为
s4 = 120
和while循环期间计算出的中间结果是
s1 = 2, s2 = 6, s3 = 24
所以有一种方法可以访问这些张量,因此我可以使用tf.gradient计算相应的梯度,比如说dS4/dS1
:
ds4_ds1 = tf.gradient(s4, s1)[0]