tf.while_loop与tf.concat的奇怪行为

时间:2020-05-12 01:40:53

标签: python tensorflow

我正在尝试通过级联元素来构建矩阵,并将该矩阵作为tf.while_loop中的输入重用,tf.concat的输出应该是形状(x+1,y)的张量,但它返回(x + 1)个(y,)形状的张量,导致两个结构不具有相同的嵌套结构。

以下是重新产生问题的代码段:

@tf.function(autograph=False)
def cond(a):
    if a.shape[0] != None:
        return tf.less(a.shape[0],5)
    else:
        return tf.constant(True)


def loop():
    ans = tf.while_loop(cond,
                        lambda a: tf.concat([a,tf.zeros((1,5),dtype=tf.float64)],axis=0),
                        loop_vars=[tf.zeros((1,5),dtype=tf.float64)],
                        shape_invariants=[tf.TensorShape((None,5))],parallel_iterations=1)

    return ans

在循环中添加@tf.function装饰器只会使其冻结,而不会给出错误消息或任何类似信息。

这是没有@tf.function装饰器的错误消息:

ValueError: The two structures don't have the same nested structure.

First structure: type=list str=[TensorSpec(shape=(1, 5), dtype=tf.float64, name=None)]

Second structure: type=list str=[<tf.Tensor: shape=(5,), dtype=float64, numpy=array([0., 0., 0., 0., 0.])>, <tf.Tensor: shape=(5,), dtype=float64, numpy=array([0., 0., 0., 0., 0.])>]

我不太确定为什么函数会返回多个张量的列表,而不是返回具有所需形状的张量的列表。

如果我将tf.concat中的轴更改为1,即使我不允许更改形状的第二个分量,该函数也不会出现任何问题。

对于任何可能会发生这种情况的解释,我将不胜感激。

1 个答案:

答案 0 :(得分:0)

我认为这在某些TensorFlow版本中并未发生,因此您可能会说这是一个错误,但是从技术上讲该错误是正确的(尽管错误消息几乎无法向您指出正确的方向)。 loop_varsshape_invariants参数都是一个列表,但是body函数返回的值是张量,而不是列表,因此结果的结构不匹配。您只需将其转换为列表即可解决该问题:

def loop():
    ans = tf.while_loop(cond,
                        lambda a: [tf.concat([a, tf.zeros((1, 5), dtype=tf.float64)], axis=0)],
                        loop_vars=[tf.zeros((1, 5), dtype=tf.float64)],
                        shape_invariants=[tf.TensorShape((None, 5))], parallel_iterations=1)
    return ans

除此之外,您的cond函数有点奇怪,您通常不需要检查某些形状属性是否为None,只需使用tf.shape即可获取实际的张量的形状:

@tf.function(autograph=False)
def cond(a):
    return tf.shape(a)[0] < 5

在TensorFlow 2.2.0中,这给我一个关于每次调用都被追溯的函数的警告,原因是参数a的形状每次都不同,但是可以避免它传递{{1 }}到tf.function

在任何情况下,如果可以使用tf.TensorArray来累积结果并仅在末尾进行连接,则效果会更好:

experimental_relax_shapes=True

尽管在这种情况下,即使使用import tensorflow as tf @tf.function(autograph=False) def cond(ta, i): return i < 5 def loop(): # Fixed size is better, use dynamic_size=True otherwise ta = tf.TensorArray(tf.float64, size=5, element_shape=(1, 5)) # Initial element ta = ta.write(0, tf.zeros((1, 5), dtype=tf.float64)) ta, _ = tf.while_loop(cond, lambda ta, i: [ta.write(i, tf.zeros((1, 5), dtype=tf.float64)), i + 1], loop_vars=[ta, 1], shape_invariants=[None, tf.TensorShape(())], parallel_iterations=1) ans = ta.concat() return ans a = loop() ,跟踪警告也不会消失。

相关问题