因此,here表示间接修改不起作用,这意味着更改将是不可见的(无论如何,不可见的更改意味着什么?)
但是此代码正确计算了梯度:
import tensorflow as tf
class C:
def __init__(self):
self.x = tf.Variable(2.0)
@tf.function
def change(self):
self.x.assign_add(2.0)
@tf.function
def func(self):
self.change()
return self.x * self.x
c = C()
with tf.GradientTape() as tape:
y = c.func()
print(tape.gradient(y, c.x)) # --> tf.Tensor(8.0, shape=(), dtype=float32)
我在这里想念东西吗?
谢谢
答案 0 :(得分:0)
文档缺少详细信息,应予以澄清-“不可见”表示AutoGraph的分析器未检测到更改。由于AutoGraph一次分析一个功能,因此分析器看不到对另一功能所做的修改。
但是,此警告不适用于具有副作用的操作,例如对TF变量的修改-那些变量仍将在图表中正确连接。因此您的代码应该可以正常工作。
此限制仅适用于对纯Python对象(列表,字典等)所做的某些更改,并且仅在使用控制流时存在问题。
例如,这是对您的代码的修改,将无法正常工作:
class C:
def __init__(self):
self.x = None
def reset(self):
self.x = tf.constant(10)
def change(self):
self.x += 1
@tf.function
def func(self):
self.reset()
for i in tf.range(3):
self.change()
return self.x * self.x
c = C()
print(c.func())
该错误消息相当模糊,但是如果您尝试不使用tf.while_loop
访问在loop_vars
主体内创建的op的结果,则会产生与该错误相同的错误消息:
<ipython-input-18-23f1641cfa01>:20 func *
return self.x * self.x
... more internal frames ...
InaccessibleTensorError: The tensor 'Tensor("add:0", shape=(),
dtype=int32)' cannot be accessed here: it is defined in another function or
code block. Use return values, explicit Python locals or TensorFlow
collections to access it. Defined in: FuncGraph(name=while_body_685,
id=5029696157776); accessed from: FuncGraph(name=func, id=5029690557264).