我正在尝试使用tf.while_loop
并行运行循环。但是,在以下玩具示例中,循环似乎没有并行运行。
iteration = tf.constant(0)
c = lambda i: tf.less(i, 1000)
def print_fun(iteration):
print(f"This is iteration {iteration}")
iteration+=1
return (iteration,)
r = tf.while_loop(c, print_fun, [iteration], parallel_iterations=10)
或
i = tf.constant(0)
c = lambda i: tf.less(i, 1000)
b = lambda i: (tf.add(i, 1),)
r = tf.while_loop(c, b, [i])
是什么导致tf.while_loop
无法并行化循环?
此外,如果维护Tensorflow文档的任何人都看到此页面,则他/她应在第一个示例中修复该错误。参见讨论here。
谢谢。
答案 0 :(得分:1)
parallel_iterations
在急切模式下运行没有任何意义,但是您始终可以使用tf.function
装饰器并获得显着的加速。可以在这张图片中看到:running times
您可以像这样用tf.while_loop
包裹tf.function
@tf.function
def run_graph():
iteration = tf.constant(0)
r = tf.while_loop(c, print_fun, [iteration], parallel_iterations=4)
,然后在需要时致电run_graph
。