如何在tensorflow中使用tf.while_loop()

时间:2016-05-25 15:09:36

标签: python tensorflow

这是一个通用的问题。我发现在张量流中,在我们构建图形之后,将数据提取到图形中,图形的输出是张量。但在很多情况下,我们需要根据此输出(tensor)进行一些计算,这在tensorflow中是不允许的。

例如,我试图实现一个RNN,它根据数据自身属性循环时间。也就是说,我需要使用tensor来判断是否应该停止(我不使用dynamic_rnn,因为在我的设计中,rnn是高度自定义的)。我发现tf.while_loop(cond,body.....)可能是我实施的候选人。但官方教程太简单了。我不知道如何在“身体”中添加更多功能。谁能给我一些更复杂的例子?

此外,在这种情况下,如果将来的计算基于张量输出(例如:基于输出标准的RNN停止),这是非常常见的情况。是否有优雅的方式或更好的方式而不是动态图?

1 个答案:

答案 0 :(得分:50)

什么阻止您向身体添加更多功能?您可以在正文中构建您喜欢的任何复杂计算图形,并从封闭图形中获取您喜欢的任何输入。此外,在循环之外,您可以随意返回任何输出。从“whatevers”的数量可以看出,TensorFlow的控制流原语在构思时考虑了很多。下面是另一个“简单”的例子,如果有帮助的话。

import tensorflow as tf
import numpy as np

def body(x):
    a = tf.random_uniform(shape=[2, 2], dtype=tf.int32, maxval=100)
    b = tf.constant(np.array([[1, 2], [3, 4]]), dtype=tf.int32)
    c = a + b
    return tf.nn.relu(x + c)

def condition(x):
    return tf.reduce_sum(x) < 100

x = tf.Variable(tf.constant(0, shape=[2, 2]))

with tf.Session():
    tf.global_variables_initializer().run()
    result = tf.while_loop(condition, body, [x])
    print(result.eval())