Canonical Tensorflow“for loop”

时间:2018-01-23 07:48:06

标签: tensorflow

运行Tensorflow“for loop”的规范方法是什么?

具体来说,假设我们有一些body函数,它不依赖于循环迭代,但必须运行n次。

有人可能会认为一个好的方法可能是在tf.while_loop内运行这样的:

def body(x):
    return ...

def while_body(i,x):
    return i+1, body(x) 

i, x = tf.while_loop(lambda i: tf.less(i, n), while_body, [tf.constant(0),x])

事实上,这正是这个问题中得分最高的答案所暗示的:

How can I run a loop with a tensor as its range? (in tensorflow)

然而,tf.while_loop docs具体说

  

对于正确的程序,while_loop应该为任何parallel_iterations返回相同的结果> 0

如果你在体内放置一个计数器,那么似乎违反了这个条件。所以似乎必须有一种不同的方式来设置“for循环”。

此外,即使没有明确的错误,这样做似乎会在迭代之间创建依赖关系,这意味着我不认为它们会并行运行。

1 个答案:

答案 0 :(得分:0)

经过一番调查,似乎上面使用的tf.while_loop成语很常见。或者,可以使用tf.scan

def body( x ):
    return ...

def scan_body( previous_output, iteration ):
    return body( ... )

x = tf.scan( scan_body, tf.range(n), initializer = [x] )

虽然我不知道从性能的角度来看是否更适合。请注意,我们必须包装body函数以接受先前的输出。