我正在尝试通过级联元素来构建矩阵,并将该矩阵作为tf.while_loop中的输入重用,tf.concat的输出应该是形状(x+1,y)
的张量,但它返回(x + 1)个(y,)
形状的张量,导致两个结构不具有相同的嵌套结构。
以下是重新产生问题的代码段:
@tf.function(autograph=False)
def cond(a):
if a.shape[0] != None:
return tf.less(a.shape[0],5)
else:
return tf.constant(True)
def loop():
ans = tf.while_loop(cond,
lambda a: tf.concat([a,tf.zeros((1,5),dtype=tf.float64)],axis=0),
loop_vars=[tf.zeros((1,5),dtype=tf.float64)],
shape_invariants=[tf.TensorShape((None,5))],parallel_iterations=1)
return ans
在循环中添加@tf.function
装饰器只会使其冻结,而不会给出错误消息或任何类似信息。
这是没有@tf.function
装饰器的错误消息:
ValueError: The two structures don't have the same nested structure.
First structure: type=list str=[TensorSpec(shape=(1, 5), dtype=tf.float64, name=None)]
Second structure: type=list str=[<tf.Tensor: shape=(5,), dtype=float64, numpy=array([0., 0., 0., 0., 0.])>, <tf.Tensor: shape=(5,), dtype=float64, numpy=array([0., 0., 0., 0., 0.])>]
我不太确定为什么函数会返回多个张量的列表,而不是返回具有所需形状的张量的列表。
如果我将tf.concat
中的轴更改为1,即使我不允许更改形状的第二个分量,该函数也不会出现任何问题。
对于任何可能会发生这种情况的解释,我将不胜感激。
答案 0 :(得分:0)
我认为这在某些TensorFlow版本中并未发生,因此您可能会说这是一个错误,但是从技术上讲该错误是正确的(尽管错误消息几乎无法向您指出正确的方向)。 loop_vars
和shape_invariants
参数都是一个列表,但是body函数返回的值是张量,而不是列表,因此结果的结构不匹配。您只需将其转换为列表即可解决该问题:
def loop():
ans = tf.while_loop(cond,
lambda a: [tf.concat([a, tf.zeros((1, 5), dtype=tf.float64)], axis=0)],
loop_vars=[tf.zeros((1, 5), dtype=tf.float64)],
shape_invariants=[tf.TensorShape((None, 5))], parallel_iterations=1)
return ans
除此之外,您的cond
函数有点奇怪,您通常不需要检查某些形状属性是否为None
,只需使用tf.shape
即可获取实际的张量的形状:
@tf.function(autograph=False)
def cond(a):
return tf.shape(a)[0] < 5
在TensorFlow 2.2.0中,这给我一个关于每次调用都被追溯的函数的警告,原因是参数a
的形状每次都不同,但是可以避免它传递{{1 }}到tf.function
。
在任何情况下,如果可以使用tf.TensorArray
来累积结果并仅在末尾进行连接,则效果会更好:
experimental_relax_shapes=True
尽管在这种情况下,即使使用import tensorflow as tf
@tf.function(autograph=False)
def cond(ta, i):
return i < 5
def loop():
# Fixed size is better, use dynamic_size=True otherwise
ta = tf.TensorArray(tf.float64, size=5, element_shape=(1, 5))
# Initial element
ta = ta.write(0, tf.zeros((1, 5), dtype=tf.float64))
ta, _ = tf.while_loop(cond,
lambda ta, i: [ta.write(i, tf.zeros((1, 5), dtype=tf.float64)), i + 1],
loop_vars=[ta, 1],
shape_invariants=[None, tf.TensorShape(())],
parallel_iterations=1)
ans = ta.concat()
return ans
a = loop()
,跟踪警告也不会消失。