Tensorflow 2.0 Autograph间接修改(隐藏状态)在不应该运行时起作用

时间:2019-12-05 00:14:30

标签: python tensorflow tensorflow2.0 gradienttape

因此,here表示间接修改不起作用,这意味着更改将是不可见的(无论如何,不​​可见的更改意味着什么?)

但是此代码正确计算了梯度:

import tensorflow as tf


class C:
    def __init__(self):
        self.x = tf.Variable(2.0)

    @tf.function
    def change(self):
        self.x.assign_add(2.0)

    @tf.function
    def func(self):
        self.change()
        return self.x * self.x


c = C()
with tf.GradientTape() as tape:
    y = c.func()
print(tape.gradient(y, c.x)) # --> tf.Tensor(8.0, shape=(), dtype=float32)

我在这里想念东西吗?

谢谢

1 个答案:

答案 0 :(得分:0)

文档缺少详细信息,应予以澄清-“不可见”表示AutoGraph的分析器未检测到更改。由于AutoGraph一次分析一个功能,因此分析器看不到对另一功能所做的修改。

但是,此警告不适用于具有副作用的操作,例如对TF变量的修改-那些变量仍将在图表中正确连接。因此您的代码应该可以正常工作。

此限制仅适用于对纯Python对象(列表,字典等)所做的某些更改,并且仅在使用控制流时存在问题。

例如,这是对您的代码的修改,将无法正常工作:

class C:
    def __init__(self):
        self.x = None

    def reset(self):
        self.x = tf.constant(10)

    def change(self):
        self.x += 1

    @tf.function
    def func(self):
      self.reset()
      for i in tf.range(3):
        self.change()
      return self.x * self.x


c = C()
print(c.func())

该错误消息相当模糊,但是如果您尝试不使用tf.while_loop访问在loop_vars主体内创建的op的结果,则会产生与该错误相同的错误消息:

    <ipython-input-18-23f1641cfa01>:20 func  *
        return self.x * self.x

    ... more internal frames ...

    InaccessibleTensorError: The tensor 'Tensor("add:0", shape=(),
dtype=int32)' cannot be accessed here: it is defined in another function or
code block. Use return values, explicit Python locals or TensorFlow
collections to access it. Defined in: FuncGraph(name=while_body_685,
id=5029696157776); accessed from: FuncGraph(name=func, id=5029690557264).