尝试了解AutoGraph和tf.function:tf.function中的打印损失

时间:2019-05-25 08:44:12

标签: tensorflow tensorflow2.0

def train_one_step():
    with tf.GradientTape() as tape:
        a = tf.random.normal([1, 3, 1])
        b = tf.random.normal([1, 3, 1])
        loss = mse(a, b)

    tf.print('inner tf print', loss)
    print("inner py print", loss)

    return loss


@tf.function
def train():
    loss = train_one_step()

    tf.print('outer tf print', loss)
    print('outer py print', loss)

    return loss

loss = train()
tf.print('outest tf print', loss)
print("outest py print", loss)

我想更多地了解tf.function。我用不同的方法在四个地方打印了损失。并产生这样的结果

inner py print Tensor("mean_absolute_error/weighted_loss/value:0", shape=(), dtype=float32)
outer py print Tensor("mean_absolute_error/weighted_loss/value:0", shape=(), dtype=float32)
inner tf print 1.82858419
outer tf print 1.82858419
outest tf print 1.82858419
outest py print tf.Tensor(1.8285842, shape=(), dtype=float32)
  1. tf.print和python print有什么区别?
  2. 看起来像是在执行图形评估时将执行python打印,但是仅在执行时才执行tf打印吗?
  3. 以上仅适用于有tf.function装饰器的情况?除此之外,tf.print在python打印之前运行?

2 个答案:

答案 0 :(得分:1)

print是正常的python打印。 tf.print是张量流图的一部分。 在渴望模式下,tensorflow将直接执行该图。这就是为什么在@tf.function函数之外,python print的输出是一个数字(tensorflow直接执行该图并将该数字提供给正常的print函数),这也是为什么tf.print立即打印的原因。

另一方面,在@tf.function函数内部,tensorflow不会立即执行该图。相反,它将把您调用的张量流函数“堆叠”到更大的图中,我们将在@tf.function的末尾立即执行该函数。

这就是为什么python打印不给您@tf.function函数内部的数字的原因(此时图形尚未执行)。但是,函数结束后,图形将与图形中的tf.print一起执行。因此,tf.print是在python打印之后打印的,并为您提供实际的丢失编号。

答案 1 :(得分:1)

我在由三部分组成的文章中涵盖并回答了您的所有问题:“分析tf.function以发现AutoGraph的优势和精妙之处”:part 1part 2part 3

总结并回答您的3个问题:

  • tf.print和python print有什么区别?

tf.print是一个Tensorflow构造,默认情况下会以标准错误进行打印,更重要的是,它在评估时会产生操作。

运行某项操作时,也渴望执行,它或多或少地以与Tensorflow 1.x相同的方式产生一个“节点”。

tf.function能够捕获tf.print的生成的操作并将其转换为图形节点。

相反,print是一个Python构造,默认情况下会在标准输出上打印,并且在执行时不生成任何操作。因此,tf.function无法将其转换为等效的图形,而只能在函数跟踪期间执行。

  • 看起来像是在执行图形评估时将执行python打印,但是仅在执行时才执行tf打印吗?

我在上一点已经回答了这个问题,但是再次重申,print仅在函数跟踪期间执行,而tf.print仅在跟踪过程中以及在其图形表示时执行(在tf.function成功将函数转换为图形之后)。

  • 以上仅适用于有tf.function装饰器的情况?除此之外,tf.print在python打印之前运行?

tf.print不在print之前或之后运行。在急切的执行中,Python解释器一旦找到该语句,便会对它们进行评估。渴望执行的唯一区别是输出流。

无论如何,我建议您阅读链接的三篇文章,因为它们详细介绍了tf.function的这一点和其他特点。