TensorFlow:如何在tf.while_loop()中获取变量的中间值?

时间:2016-09-27 02:50:41

标签: tensorflow

我需要在tf.while_loop()中获取张量的中间值,但是,它只给我最后返回的值。

例如,我有一个变量x,它有3个页面,其尺寸为3 * 2 * 4。现在我想要获取每个页面一次并计算每页的总和,页面总和,平均值,最大值和最小值。然后我定义条件和body函数,并希望使用tf.while_loop()来计算所需的结果。源代码如下所示。

import tensorflow as tf    
x = tf.constant([[[41, 8, 48,  82],
                      [9, 56, 67,  23]],
                     [[95, 89, 44,  54],
                      [11, 33, 29,  1]],
                     [[34,  9,  5,  70],
                      [14, 35, 18,  17]]], dtype=tf.int32)

def cond(out, count, x):
  return count < 3

def body(out, count, x):
  outTemp = tf.slice(x, [count, 0, 0], [1, -1, -1])
  count += 1
  outPack = tf.unpack(out)
  outPack[0] += tf.reduce_sum(outTemp)
  outPack[1] = tf.reduce_sum(outTemp)
  outPack[2] = tf.reduce_mean(outTemp)
  outPack[3] = tf.reduce_max(outTemp)
  outPack[4] = tf.reduce_min(outTemp)
  out = tf.pack(outPack)
  return out, count, x

out = tf.Variable(tf.constant([0, 0, 0, 0, 0]))  # total sum, page sum, mean, max, min
count = tf.Variable(tf.constant(0))
result = tf.while_loop(cond, body, [out, count, x])

init = tf.initialize_all_variables()
with tf.Session() as sess:
  sess.run(init)
  print(sess.run(x))
  print(sess.run(result)[0])

当我运行程序时,它只给我上次返回的值,我只能得到最后一页的结果。 所以问题是,如何获得每个页面的结果以及如何从tf.while_loop()中获取中间值?

谢谢。

1 个答案:

答案 0 :(得分:1)

获得&#34;中间值&#34;在任何变量中,您可以简单地使用tf.Print op,它实际上是一种身份操作,在评估上述变量时会产生打印相关消息的副作用。

举个例子,

int

可以放在您希望报告值的任何行中。