Tensorflow 2中的示例代码
writer = tf.summary.create_file_writer("/tmp/mylogs/tf_function")
@tf.function
def my_func(step):
with writer.as_default():
# other model code would go here
tf.summary.scalar("my_metric", 0.5, step=step)
for step in range(100):
my_func(step)
writer.flush()
但是它正在抛出警告。
警告:tensorflow:在触发tf.function跟踪的最后5次调用中有5次。追踪很昂贵 并且过多的跟踪可能是由于传递了python 对象而不是张量。另外,tf.function有 Experiment_relax_shapes =放宽参数形状的True选项 可以避免不必要的追溯。请参阅 https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args 和https://www.tensorflow.org/api_docs/python/tf/function了解更多 详细信息。
有更好的方法吗?
答案 0 :(得分:1)
tf.function
有一些“特殊之处”。我强烈建议您阅读这篇文章:https://www.tensorflow.org/tutorials/customization/performance
在这种情况下,问题在于,每次您使用不同的 input签名调用时,该函数都会被“追溯”(即,建立了一个新图形)。对于张量,输入签名指的是shape和dtype,但对于Python数字,每个新值均被解释为“不同”。在这种情况下,由于您使用每次更改的step
变量来调用函数,因此该函数也每次都被追溯。对于“真实”代码(例如,在函数内部调用模型),这将非常慢。
您可以通过简单地将step
转换为张量来解决它,在这种情况下,不同的值将不计为新的输入签名:
for step in range(100):
step = tf.convert_to_tensor(step, dtype=tf.int64)
my_func(step)
writer.flush()
或使用tf.range
直接获得张量:
for step in tf.range(100):
step = tf.cast(step, tf.int64)
my_func(step)
writer.flush()
这不应产生警告(并且速度要快得多)。