我应该对所有功能都使用@ tf.function吗?

时间:2020-01-21 18:21:10

标签: tensorflow keras tensorflow2.0 tf.keras tensorflow2.x

@tf.function上的official tutorial说:

要获得最佳性能并使模型可部署到任何地方, 使用tf.function从程序中绘制图形。谢谢 AutoGraph,可以使用数量惊人的Python代码 tf.function,但仍然要提防陷阱。

主要要点和建议是:

  • 不要依赖于对象更改或列表追加之类的Python副作用。
  • tf.function最适合TensorFlow操作,而不是NumPy操作或Python原语。
  • 如有疑问,请在y惯用法中使用for。

它只提到如何实施@tf.function带注释的功能,而没有提及何时使用它。

是否有试探性决定如何确定是否至少应尝试使用tf.function来注释函数?似乎没有理由不这样做,除非我懒于消除副作用或更改诸如range()-> tf.range()之类的东西。但是如果我愿意这样做...

是否有理由不对所有功能都使用@tf.function

3 个答案:

答案 0 :(得分:14)

TLDR:这取决于您的功能以及您是生产还是开发。如果您希望能够轻松调试功能,或者该功能受AutoGraph或tf.v1代码兼容性的限制,请不要使用tf.function。 我强烈建议观看Inside TensorFlow关于AutoGraphFunctions, not Sessions的话题。

以下,我将详细说明原因,这些原因均来自Google在线提供的信息。

通常,tf.function装饰器使函数被编译为执行TensorFlow图的可调用函数。这需要:

  • 如有必要,可以通过AutoGraph转换代码(包括从带注释的函数调用的任何函数)
  • 跟踪并执行生成的图形代码

There is detailed information available on the design ideas behind this.

tf.function装饰函数的好处

一般利益

  • 执行速度更快,尤其是当函数包含许多小操作(Source)

对于具有Python代码的函数/通过tf.function装饰使用AutoGraph

如果要使用AutoGraph,强烈建议使用tf.function而不是直接调用AutoGraph。 这样做的原因包括:自动控制依赖项,某些API需要它,更多的缓存,以及异常帮助器(Source)

tf.function装饰函数的缺点

一般弊端

  • 如果该功能仅包含少量昂贵的操作,则提速(Source)

对于具有Python代码的函数/通过tf.function装饰使用AutoGraph

  • 没有异常捕获(应该在热切的模式下;在修饰的函数之外)(Source)
  • 调试要困难得多
  • 由于隐藏的副作用和TF控制流而引起的限制

Detailed information on AutoGraph limitations is available.

用于带有tf.v1代码的功能

  • 不允许在tf.function中多次创建变量,但这可能会随着tf.v1代码(Source)的淘汰而改变

用于带有tf.v2代码的功能

  • 没有具体缺点

限制示例

多次创建变量

不允许多次创建变量,例如以下示例中的v

@tf.function
def f(x):
    v = tf.Variable(1)
    return tf.add(x, v)

f(tf.constant(2))

# => ValueError: tf.function-decorated function tried to create variables on non-first call.

在下面的代码中,通过确保仅创建一次self.v来减轻这种情况:

class C(object):
    def __init__(self):
        self.v = None
    @tf.function
    def f(self, x):
        if self.v is None:
            self.v = tf.Variable(1)
        return tf.add(x, self.v)

c = C()
print(c.f(tf.constant(2)))

# => tf.Tensor(3, shape=(), dtype=int32)

AutoGraph无法捕获的隐藏副作用

在此示例中,诸如self.a之类的更改无法隐藏,这导致错误,因为尚未进行跨功能分析((Source)):

class C(object):
    def change_state(self):
        self.a += 1

    @tf.function
    def f(self):
        self.a = tf.constant(0)
        if tf.constant(True):
            self.change_state() # Mutation of self.a is hidden
        tf.print(self.a)

x = C()
x.f()

# => InaccessibleTensorError: The tensor 'Tensor("add:0", shape=(), dtype=int32)' cannot be accessed here: it is defined in another function or code block. Use return values, explicit Python locals or TensorFlow collections to access it. Defined in: FuncGraph(name=cond_true_5, id=5477800528); accessed from: FuncGraph(name=f, id=5476093776).

完全改变没有问题:

class C(object):
    @tf.function
    def f(self):
        self.a = tf.constant(0)
        if tf.constant(True):
            self.a += 1 # Mutation of self.a is in plain sight
        tf.print(self.a)

x = C()
x.f()

# => 1

由于TF控制流程而受到限制的示例

此if语句会导致错误,因为需要为TF控制流定义else的值:

@tf.function
def f(a, b):
    if tf.greater(a, b):
        return tf.constant(1)

# If a <= b would return None
x = f(tf.constant(3), tf.constant(2))   

# => ValueError: A value must also be returned from the else branch. If a value is returned from one branch of a conditional a value must be returned from all branches.

答案 1 :(得分:3)

tf.function在创建和使用计算图时很有用,应在培训和部署中使用它们,但是大多数功能不需要它。

让我们说我们正在构建一个特殊的层,它将与更大的模型分开。我们不希望在构造该层的函数上方使用tf.function装饰器,因为它只是该层外观的定义。

另一方面,可以说我们将进行预测或使用某些功能继续进行训练。我们想要装饰器tf.function,因为我们实际上是在使用计算图来获得一些值。

一个很好的例子是构造一个编码器-解码器模型。 不要将装饰器放在创建编码器或解码器或任何层的函数周围,这仅是其作用的定义。 不要将装饰器放在“火车”或“预测”方法的周围,因为实际上它们将使用计算图进行计算。

答案 2 :(得分:1)

根据我的理解,并根据文档,强烈建议使用tf.function主要是为了加快代码的速度,因为tf.function所包装的代码将转换为图形,因此仍有空间急切地执行某些优化(例如,操作修剪,折叠等),而当急于运行相同的代码时,可能无法执行这些优化。

但是,在某些情况下,使用tf.function可能会产生额外的开销或不会导致明显的加速。一种值得注意的情况是,包装的函数在代码中,并且仅使用了几次,因此调用图的开销可能相对较大。另一种情况是,大多数计算已在加速器设备(例如GPU,TPU)上完成,因此通过图形计算获得的加速效果可能并不明显。

还有a section in the documentation,其中讨论了各种情况下的加速,并且在本节的开头提到了以上两种情况:

仅将张量使用函数包装在tf.function中并不会自动加速代码。对于在单个计算机上多次调用的小函数,调用图形或图形片段的开销可能会占主导地位。另外,如果大多数计算已经在加速器上进行,例如GPU重卷积的堆栈,则图形加速不会很大。

对于复杂的计算,图形可以显着提高速度。这是因为图形减少了Python与设备之间的通信并提高了速度。

但是,归根结底,如果它适用于您的工作流程,我认为针对特定用例和环境确定此结果的最佳方法是在执行时对代码进行概要分析急切模式(即不使用tf.function)与在图形模式下执行时(即广泛使用tf.function)。