使用tf.while_loop()将TensorFlow插入无限循环

时间:2016-06-01 14:04:32

标签: python tensorflow

重现的步骤

我正在使用TensorFlow来实现需要使用tf.while_loop()

的网络
import tensorflow as tf
import numpy as np
class model(object):
    def __init__(self):
        self.argmax_ep_gate_array = [ tf.placeholder(tf.int32, [None]) for _ in range(10)]
        argmax_ep_gate_array_concat = tf.concat(0, self.argmax_ep_gate_array)
        story_len = tf.constant(7)
        starter = tf.constant(0)
        z = []
        def body(hops):
            hops = tf.add(hops,1)
            z.append(hops)
            return hops
        def condition(hops):
            return tf.logical_and(tf.less(tf.gather(argmax_ep_gate_array_concat, hops),story_len),tf.less(hops,tf.constant(20)))

        self.gate_index = tf.while_loop(condition,body,[starter])
        self.z=tf.concat(0,z)

    def step(self, sess):
        feed={}
        for i in range(10):
            feed[self.argmax_ep_gate_array[i].name]=[i]
        print (sess.run([self.gate_index,self.z],feed))
with tf.Session() as sess:
    while_loop = model()
    sess.run(tf.initialize_all_variables())
    while_loop.step(sess)

你有什么尝试?

我发现如果我想sess.run()body()中没有返回的任何变量,tensorflow会陷入无限循环。 上面的例子是微不足道的,但它揭示了一些东西。在实际情况中,我使用运行RNN的tf.while_loop()包含y = wx + b之类的东西,但在while循环之后不返回wb。在前向网络中,它工作正常。但是,如果我运行反向传播,程序将陷入无限循环。我想上面的代码重现了我的问题,因为反向传播确实需要修改wb。或者有什么方法可以解决这个问题吗?

1 个答案:

答案 0 :(得分:8)

TL; DR:您无法存储在循环体中创建的张量供以后使用,因为这会破坏有关循环结构的一些假设。

通常,condition()body()函数不得有副作用。 实际上,您的程序不太可能具有预期的行为:TensorFlow将执行body()函数一次,以构建必要的图形结构,因此z将只包含一个元素在运行model.__init__()

之后

相反,您必须使用z在循环体中逐步构造tf.concat()并将值作为循环变量生成:

starter = tf.constant(0)
z_initial = tf.constant([], dtype=tf.int32)

def body(hops, z_prev):
    hops = tf.add(hops, 1)
    z_next = tf.concat(0, [z_prev, tf.expand_dims(hops, 0)])
    return hops, z_next
def condition(hops, z):
    return tf.logical_and(tf.less(tf.gather(
        argmax_ep_gate_array_concat, hops), story_len), tf.less(hops, tf.constant(20)))

self.gate_index, self.z = tf.while_loop(condition,body,[starter, z_initial])