Tensorflow:张量上的while循环

时间:2018-11-26 14:22:25

标签: python tensorflow tensor

我正在尝试对张量的值应用while循环。例如,对于变量“ a”,我尝试逐渐增加张量的值,直到满足特定条件为止。但是,我不断收到此错误:

  

ValueError:“ while_12 / LoopCond”的形状必须为0级,但必须为3级   (操作:“ LoopCond”),输入形状为[3,1,1]。

a = array([[[0.76393723]],
       [[0.93270312]],
       [[0.08361106]]])

a = np.random.random((3,1,1))
a1 = tf.constant(np.float64(a))
i = tf.constant(np.float64(6.14))

c = lambda i: tf.less(i, a1)
b = lambda x: tf.add(x, 0.1)
r = tf.while_loop(c, b, [a1])

1 个答案:

答案 0 :(得分:4)

tf.while_loop()的第一个参数应返回标量(等级0的张量实际上是标量-这就是错误消息的含义)。在您的示例中,如果true张量中的所有数字均小于a1,则您可能希望使条件返回6.14。这可以通过tf.reduce_all()(逻辑与)和tf.reduce_any()(逻辑或)来实现。

该代码段对我有用:

tf.reset_default_graph()

a = np.random.random_integers(3, size=(3,2))
print(a)
# [[1 1]
#  [2 3]
#  [1 1]]

a1 = tf.constant(a)
i = 6

# condition returns True till any number in `x` is less than 6
condition = lambda x : tf.reduce_any(tf.less(x, i))
body      = lambda x : tf.add(x, 1)
loop = tf.while_loop(
    condition,
    body,
    [a1],
)

with tf.Session() as sess:
    result = sess.run(loop)
    print(result)
    # [[6 6]
    #  [7 8]
    #  [6 6]]
    # All numbers now are greater than 6