如何在急切执行中暂停和恢复渐变编带?

时间:2018-06-15 20:02:24

标签: python tensorflow

我的任务看起来像这样:

# compute estimates from input
net_estimate = my_model(inputs)
# use this estimate to compute a target
target_estimate = lots_of_computations(net_estimate)
# compute loss
loss = compute_loss(net_estimate, target_estimate)

(对于某些情况,这是一个强化学习任务,由此产生的状态 - 和奖励 - 取决于网络采取的行动。)

问题是我不想(实际上不能)计算lots_of_computations的梯度。理想情况下,我想暂停并恢复渐变编带

with tf.GradientTape() as tape:
  net_estimate = my_model(inputs)
# target_estimate should be considered a constant
target_estimate = lots_of_computations(net_estimate)
with tape.resume():
  loss = compute_loss(net_estimate, target_estimate)
tape.gradient(loss, my_model.params)

GradientTape似乎没有提供类似的东西。有没有办法在急切模式下实现这一目标?我目前的解决方法是计算net_estimate两次,但这显然不是最理想的。

1 个答案:

答案 0 :(得分:1)

tf.GradientTape.stop_recording可能就是你要找的东西。

最近推出(在TensorFlow 1.8之后),所以目前你需要使用TensorFlow 1.9.0的候选版本。

希望有所帮助。