使用@tffunction的Tensorflow2警告

时间:2019-11-21 10:05:47

标签: tensorflow

Tensorflow 2中的示例代码

writer = tf.summary.create_file_writer("/tmp/mylogs/tf_function")

@tf.function
def my_func(step):
  with writer.as_default():
    # other model code would go here
    tf.summary.scalar("my_metric", 0.5, step=step)

for step in range(100):
  my_func(step)
  writer.flush()

但是它正在抛出警告。

  

警告:tensorflow:在触发tf.function跟踪的最后5次调用中有5次。追踪很昂贵   并且过多的跟踪可能是由于传递了python   对象而不是张量。另外,tf.function有   Experiment_relax_shapes =放宽参数形状的True选项   可以避免不必要的追溯。请参阅   https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args   和https://www.tensorflow.org/api_docs/python/tf/function了解更多   详细信息。

有更好的方法吗?

1 个答案:

答案 0 :(得分:1)

tf.function有一些“特殊之处”。我强烈建议您阅读这篇文章:https://www.tensorflow.org/tutorials/customization/performance

在这种情况下,问题在于,每次您使用不同的 input签名调用时,该函数都会被“追溯”(即,建立了一个新图形)。对于张量,输入签名指的是shape和dtype,但对于Python数字,每个新值均被解释为“不同”。在这种情况下,由于您使用每次更改的step变量来调用函数,因此该函数也每次都被追溯。对于“真实”代码(例如,在函数内部调用模型),这将非常慢。

您可以通过简单地将step转换为张量来解决它,在这种情况下,不同的值将计为新的输入签名:

for step in range(100):
    step = tf.convert_to_tensor(step, dtype=tf.int64)
    my_func(step)
    writer.flush()

或使用tf.range直接获得张量:

for step in tf.range(100):
    step = tf.cast(step, tf.int64)
    my_func(step)
    writer.flush()

这不应产生警告(并且速度要快得多)。