tf.while_loop中的副作用

时间:2018-04-06 06:39:40

标签: python tensorflow while-loop

我目前很难理解tensorflow是如何工作的,我觉得python界面有点模糊。

我最近尝试在tf.while_loop中运行一个简单的print语句,还有很多东西对我来说还不清楚:

import tensorflow as tf

nb_iter = tf.constant(value=10)
#This solution does not work at all
#nb_iter = tf.get_variable('nb_iter', shape=(1), dtype=tf.int32, trainable=False)
i = tf.get_variable('i', shape=(), trainable=False,
                     initializer=tf.zeros_initializer(), dtype=nb_iter.dtype)

loop_condition = lambda i: tf.less(i, nb_iter)
def loop_body(i):
    tf.Print(i, [i], message='Another iteration')
    return [tf.add(i, 1)]

i = tf.while_loop(loop_condition, loop_body, [i])

initializer_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(initializer_op)
    res = sess.run(i)
    print('res is now {}'.format(res))

请注意,如果我使用

初始化nb_iter
nb_iter = tf.get_variable('nb_iter', shape=(1), dtype=tf.int32, trainable=False)

我收到以下错误:

  

ValueError:Shape必须为0级,但是'而/ LoopCond'   (op:' LoopCond')输入形状:[1]。

当我尝试使用' i'索引张量的索引(此处未显示示例),然后我得到以下错误

  

alueError:操作' while / strided_slice'被标记为没有   可提取。

有人可以给我一个文档来解释tf.while_loop在与tf.Variables一起使用时如何工作,以及是否可以在循环中使用side_effects(如print),以及使用循环变量索引张量?< / p>

提前感谢您的帮助

1 个答案:

答案 0 :(得分:1)

我的第一个例子实际上有很多问题:

如果操作员没有副作用(即i = tf.Print()),则不执行打印。

如果布尔值是标量,则它是秩0张量,而不是秩-1张量。 ...

以下是有效的代码:

import tensorflow as tf

#nb_iter = tf.constant(value=10)
#This solution does not work at all
nb_iter = tf.get_variable('nb_iter', shape=(), dtype=tf.int32, trainable=False,
                          initializer=tf.zeros_initializer())
nb_iter = tf.add(nb_iter,10)
i = tf.get_variable('i', shape=(), trainable=False,
                     initializer=tf.zeros_initializer(), dtype=nb_iter.dtype)
v = tf.get_variable('v', shape=(10), trainable=False,
                     initializer=tf.random_uniform_initializer, dtype=tf.float32)

loop_condition = lambda i: tf.less(i, nb_iter)
def loop_body(i):
    i = tf.Print(i, [v[i]], message='Another vector element: ')
    return [tf.add(i, 1)]

i = tf.while_loop(loop_condition, loop_body, [i])

initializer_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(initializer_op)
    res = sess.run(i)
    print('res is now {}'.format(res))

输出:

Another vector element: [0.203766704]
Another vector element: [0.692927241]
Another vector element: [0.732221603]
Another vector element: [0.0556482077]
Another vector element: [0.422092319]
Another vector element: [0.597698212]
Another vector element: [0.92387116]
Another vector element: [0.590101123]
Another vector element: [0.741415381]
Another vector element: [0.514917374]
res is now 10