我想在tf.scan
循环中进行一些计算。我的问题是,tf.slice
的形状似乎在tf.scan
循环中是未知的,或者与tf.scan
循环之外的形状相比还是未知的。我想出的MWE看起来像这样:
def compute(x, _):
i = x[0]
a = x[1]
print(tf.slice(a, [0], [i]).shape) # First line
return (i + 1, a)
a = tf.constant([1, 2, 3, 4, 5])
with tf.Session() as sess:
print(tf.slice(a, [0], [tf.constant(1)]).shape) # Second line
x = tf.scan(
compute,
tf.zeros(a.shape),
initializer=(tf.constant(1, tf.int32), a)
)
返回以下两行:
(1,)
(?,)
为什么tf.slice
在(?,)
函数中返回compute
?有没有办法将切片铸成特定形状?
答案 0 :(得分:2)
tf.slice
返回形状为(?,)
的张量的原因很简单。此方法具有以下签名:
tf.slice(input_, begin, size, name=None)
重写调用tf.slice(a, [0], [i])
的Python方法是a[0:i]
。
在tf.scan
的每次迭代中,i
内的张量compute
是从先前迭代的结果派生的,可以改变并且具有任意值。在编译期间,TensorFlow无法推断切片的确切形状并返回(?,)
。
在您的情况下,每次迭代i
都会增加,并且您会得到各种形状的切片。但是,如果您不更新i
并具有相同大小的切片,则输出仍将是(?,)
,因为i
的值来自外部迭代和先前的迭代。
如果您对切片的形状充满信心,那么在运行时就不必铸造切片的形状。例如,根据您的情况,每次迭代的形状都不同。 tf.scan
的主要规则是输出和初始化张量具有一致的形状,否则将抛出ValueError
。