我的任务看起来像这样:
# compute estimates from input
net_estimate = my_model(inputs)
# use this estimate to compute a target
target_estimate = lots_of_computations(net_estimate)
# compute loss
loss = compute_loss(net_estimate, target_estimate)
(对于某些情况,这是一个强化学习任务,由此产生的状态 - 和奖励 - 取决于网络采取的行动。)
问题是我不想(实际上不能)计算lots_of_computations
的梯度。理想情况下,我想暂停并恢复渐变编带
with tf.GradientTape() as tape:
net_estimate = my_model(inputs)
# target_estimate should be considered a constant
target_estimate = lots_of_computations(net_estimate)
with tape.resume():
loss = compute_loss(net_estimate, target_estimate)
tape.gradient(loss, my_model.params)
但GradientTape
似乎没有提供类似的东西。有没有办法在急切模式下实现这一目标?我目前的解决方法是计算net_estimate
两次,但这显然不是最理想的。
答案 0 :(得分:1)
tf.GradientTape.stop_recording
可能就是你要找的东西。
最近推出(在TensorFlow 1.8之后),所以目前你需要使用TensorFlow 1.9.0的候选版本。
希望有所帮助。