TF 2.0 while_loop和parallel_iterations

时间:2019-12-12 06:42:57

标签: python tensorflow parallel-processing

我正在尝试使用tf.while_loop并行运行循环。但是,在以下玩具示例中,循环似乎没有并行运行。

iteration = tf.constant(0)
c = lambda i: tf.less(i, 1000)
def print_fun(iteration):
    print(f"This is iteration {iteration}")
    iteration+=1
    return (iteration,)
r = tf.while_loop(c, print_fun, [iteration], parallel_iterations=10)

i = tf.constant(0)
c = lambda i: tf.less(i, 1000)
b = lambda i: (tf.add(i, 1),)
r = tf.while_loop(c, b, [i])

是什么导致tf.while_loop无法并行化循环?

此外,如果维护Tensorflow文档的任何人都看到此页面,则他/她应在第一个示例中修复该错误。参见讨论here

谢谢。

1 个答案:

答案 0 :(得分:1)

parallel_iterations在急切模式下运行没有任何意义,但是您始终可以使用tf.function装饰器并获得显着的加速。可以在这张图片中看到:running times

您可以像这样用tf.while_loop包裹tf.function

@tf.function
def run_graph():
    iteration = tf.constant(0)
    r = tf.while_loop(c, print_fun, [iteration], parallel_iterations=4)

,然后在需要时致电run_graph