TensorFlow梯度w.r.t. while_loop主体内部的中间结果

时间:2019-02-27 13:59:42

标签: tensorflow gradient

我已经使用tf.scan使用TensorFlow实现了一些递归模型,目前我正在尝试获取输出损耗w.r.t的梯度。出于研究目的在每个时间步上都注明。

可以将关键过程简化为计算梯度w.r.t. while_loop主体中的中间值:

import tensorflow as tf

tf.InteractiveSession()

sequence = tf.constant([2, 3, 4, 5])
def body(t, previous_s):
  intermediate_s = previous_s * sequence[t]
  return t + 1, intermediate_s

t0, s0 = 0, 1
_, s4 = tf.while_loop(lambda t, _: t < sequence.shape[0], body, (t0, s0))
print('s4 =', s4.eval())

输出为

s4 = 120

和while循环期间计算出的中间结果是

s1 = 2, s2 = 6, s3 = 24

所以有一种方法可以访问这些张量,因此我可以使用tf.gradient计算相应的梯度,比如说dS4/dS1

ds4_ds1 = tf.gradient(s4, s1)[0]

0 个答案:

没有答案