具有ValueError Shape的while_loop必须为0级,但对于'while / LoopCond'为2级

时间:2018-05-24 08:40:22

标签: python loops tensorflow rank

x=([1.,2.],
   [2.,1.])
xtensor = tf.convert_to_tensor(x)
A = xtensor
B = xtensor
def cond(now,pre):
   return (tf.greater(now,pre))
def body(now,pre):
   return pre,now
A,now = tf.while_loop(cond,body,[A,B])
with tf.Session() as sess:
   sess.run(tf.global_variables_initializer())
   A = sess.run(A)
   B = sess.run(B)
   now = sess.run(now)

实际上,可以忽略代码的用途,因为我简化了两个函数来提出这个问题并且仍然有相同的错误:

ValueError:Shape must be rank 0 but is rank 2 for 'while/LoopCond'(op:'LoopCond') with input shapes:[2,2].

我真的很困惑.....希望有人可以帮助我。非常感谢!

1 个答案:

答案 0 :(得分:1)

cond的条件函数(tf.while_loop())必须返回秩0的布尔张量(即形状[],即单个布尔值)。你的cond返回一个排名为2的布尔张量(因为tf.greater(now, pre)返回一个与now形状相同的张量,执行每元素比较。)