def train_one_step():
with tf.GradientTape() as tape:
a = tf.random.normal([1, 3, 1])
b = tf.random.normal([1, 3, 1])
loss = mse(a, b)
tf.print('inner tf print', loss)
print("inner py print", loss)
return loss
@tf.function
def train():
loss = train_one_step()
tf.print('outer tf print', loss)
print('outer py print', loss)
return loss
loss = train()
tf.print('outest tf print', loss)
print("outest py print", loss)
我想更多地了解tf.function。我用不同的方法在四个地方打印了损失。并产生这样的结果
inner py print Tensor("mean_absolute_error/weighted_loss/value:0", shape=(), dtype=float32)
outer py print Tensor("mean_absolute_error/weighted_loss/value:0", shape=(), dtype=float32)
inner tf print 1.82858419
outer tf print 1.82858419
outest tf print 1.82858419
outest py print tf.Tensor(1.8285842, shape=(), dtype=float32)
答案 0 :(得分:1)
print
是正常的python打印。 tf.print
是张量流图的一部分。
在渴望模式下,tensorflow将直接执行该图。这就是为什么在@tf.function
函数之外,python print的输出是一个数字(tensorflow直接执行该图并将该数字提供给正常的print函数),这也是为什么tf.print立即打印的原因。
另一方面,在@tf.function
函数内部,tensorflow不会立即执行该图。相反,它将把您调用的张量流函数“堆叠”到更大的图中,我们将在@tf.function
的末尾立即执行该函数。
这就是为什么python打印不给您@tf.function
函数内部的数字的原因(此时图形尚未执行)。但是,函数结束后,图形将与图形中的tf.print
一起执行。因此,tf.print
是在python打印之后打印的,并为您提供实际的丢失编号。
答案 1 :(得分:1)
我在由三部分组成的文章中涵盖并回答了您的所有问题:“分析tf.function以发现AutoGraph的优势和精妙之处”:part 1,part 2,part 3。
总结并回答您的3个问题:
tf.print
是一个Tensorflow构造,默认情况下会以标准错误进行打印,更重要的是,它在评估时会产生操作。
运行某项操作时,也渴望执行,它或多或少地以与Tensorflow 1.x相同的方式产生一个“节点”。
tf.function
能够捕获tf.print
的生成的操作并将其转换为图形节点。
相反,print
是一个Python构造,默认情况下会在标准输出上打印,并且在执行时不不生成任何操作。因此,tf.function
无法将其转换为等效的图形,而只能在函数跟踪期间执行。
我在上一点已经回答了这个问题,但是再次重申,print
仅在函数跟踪期间执行,而tf.print
仅在跟踪过程中以及在其图形表示时执行(在tf.function
成功将函数转换为图形之后)。
tf.print
不在print
之前或之后运行。在急切的执行中,Python解释器一旦找到该语句,便会对它们进行评估。渴望执行的唯一区别是输出流。
无论如何,我建议您阅读链接的三篇文章,因为它们详细介绍了tf.function
的这一点和其他特点。