我需要在tf.while_loop()中获取张量的中间值,但是,它只给我最后返回的值。
例如,我有一个变量x,它有3个页面,其尺寸为3 * 2 * 4。现在我想要获取每个页面一次并计算每页的总和,页面总和,平均值,最大值和最小值。然后我定义条件和body函数,并希望使用tf.while_loop()来计算所需的结果。源代码如下所示。
import tensorflow as tf
x = tf.constant([[[41, 8, 48, 82],
[9, 56, 67, 23]],
[[95, 89, 44, 54],
[11, 33, 29, 1]],
[[34, 9, 5, 70],
[14, 35, 18, 17]]], dtype=tf.int32)
def cond(out, count, x):
return count < 3
def body(out, count, x):
outTemp = tf.slice(x, [count, 0, 0], [1, -1, -1])
count += 1
outPack = tf.unpack(out)
outPack[0] += tf.reduce_sum(outTemp)
outPack[1] = tf.reduce_sum(outTemp)
outPack[2] = tf.reduce_mean(outTemp)
outPack[3] = tf.reduce_max(outTemp)
outPack[4] = tf.reduce_min(outTemp)
out = tf.pack(outPack)
return out, count, x
out = tf.Variable(tf.constant([0, 0, 0, 0, 0])) # total sum, page sum, mean, max, min
count = tf.Variable(tf.constant(0))
result = tf.while_loop(cond, body, [out, count, x])
init = tf.initialize_all_variables()
with tf.Session() as sess:
sess.run(init)
print(sess.run(x))
print(sess.run(result)[0])
当我运行程序时,它只给我上次返回的值,我只能得到最后一页的结果。 所以问题是,如何获得每个页面的结果以及如何从tf.while_loop()中获取中间值?
谢谢。
答案 0 :(得分:1)
获得&#34;中间值&#34;在任何变量中,您可以简单地使用tf.Print op,它实际上是一种身份操作,在评估上述变量时会产生打印相关消息的副作用。
举个例子,
int
可以放在您希望报告值的任何行中。