Using a single GradientTape in a global context

Time: 2019-10-29 17:11:47

Tags: python tensorflow tensorflow2.0

I want to use GradientTape to observe gradients in eager execution mode. Is it possible to create a GradientTape once and have it record everything, as if it had a global context?

Here is an example of what I would like to do:

import numpy as np
import tensorflow as tf

x = tf.Variable(np.ones((2,)))
y = 2*x
z = 2*y
tf.gradients(z, x) # RuntimeError, not supported in eager execution

Now, this is easy to fix:

with tf.GradientTape() as g:
    y = 2*x
    z = 2*y

g.gradient(z, x) # this works

But the problem is that I often don't define y and z right next to each other. For example, what if the code is executed in a Jupyter notebook and they live in different cells?

Can I define a GradientTape that watches everything globally?

1 Answer:

Answer 0 (score: 2)

I found a workaround:

import numpy as np
import tensorflow as tf

# persistent is not necessary for g to work globally
# it only means that gradients can be computed more than once,
# which is important for the interactive jupyter notebook use-case
g = tf.GradientTape(persistent=True)

# this is the workaround
g.__enter__()

# you can execute this anywhere, also split into separate cells
x = tf.Variable(np.ones((2,)))
y = 2*x
z = 2*y

g.gradient(z, x)
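A minimal end-to-end sketch of the same idea (assuming TensorFlow 2.x and NumPy are available): the manually entered tape keeps recording across separately executed statements, `persistent=True` allows querying it more than once, and the tape is closed explicitly with `__exit__` when recording is done:

```python
import numpy as np
import tensorflow as tf

# Enter the tape manually so it records "globally", as in the workaround above.
g = tf.GradientTape(persistent=True)
g.__enter__()

# These statements could live in separate notebook cells.
x = tf.Variable(np.ones((2,)))  # trainable variables are watched automatically
y = 2 * x
z = 2 * y

# Stop recording before querying gradients.
g.__exit__(None, None, None)

# persistent=True lets us ask the same tape for several gradients.
dz_dx = g.gradient(z, x)  # d(4x)/dx -> [4., 4.]
dy_dx = g.gradient(y, x)  # d(2x)/dx -> [2., 2.]

print(dz_dx.numpy())
print(dy_dx.numpy())
```

Without `persistent=True`, the second `g.gradient` call would raise an error, because a non-persistent tape releases its resources after the first gradient computation.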