运行Tensorflow“for loop”的规范方法是什么?
具体来说,假设我们有一些body
函数,它不依赖于循环迭代,但必须运行n
次。
有人可能会认为一个好的方法可能是在tf.while_loop
内运行这样的:
def body(x):
return ...
def while_body(i,x):
return i+1, body(x)
i, x = tf.while_loop(lambda i: tf.less(i, n), while_body, [tf.constant(0),x])
事实上,这正是这个问题中得分最高的答案所暗示的:
How can I run a loop with a tensor as its range? (in tensorflow)
然而,tf.while_loop
docs具体说
对于正确的程序,while_loop应该为任何parallel_iterations返回相同的结果> 0
如果你在体内放置一个计数器,那么似乎违反了这个条件。所以似乎必须有一种不同的方式来设置“for循环”。
此外,即使没有明确的错误,这样做似乎会在迭代之间创建依赖关系,这意味着我不认为它们会并行运行。
答案 0 :(得分:0)
经过一番调查,似乎上面使用的tf.while_loop
成语很常见。或者,可以使用tf.scan
:
def body( x ):
return ...
def scan_body( previous_output, iteration ):
return body( ... )
x = tf.scan( scan_body, tf.range(n), initializer = [x] )
虽然我不知道从性能的角度来看是否更适合。请注意,我们必须包装body
函数以接受先前的输出。