tf.scan掩盖了Tensor的形状

时间:2018-10-28 14:11:19

标签: python tensorflow

我想在tf.scan循环中进行一些计算。我的问题是,tf.slice的形状似乎在tf.scan循环中是未知的,或者与tf.scan循环之外的形状相比还是未知的。我想出的MWE看起来像这样:

def compute(x, _):
    i = x[0]
    a = x[1]
    print(tf.slice(a, [0], [i]).shape) # First line
    return (i + 1, a)

a = tf.constant([1, 2, 3, 4, 5])
with tf.Session() as sess:
    print(tf.slice(a, [0], [tf.constant(1)]).shape) # Second line
    x = tf.scan(
        compute,
        tf.zeros(a.shape),
        initializer=(tf.constant(1, tf.int32), a)
    )

返回以下两行:

(1,)
(?,)

为什么tf.slice(?,)函数中返回compute?有没有办法将切片铸成特定形状?

1 个答案:

答案 0 :(得分:2)

tf.slice返回形状为(?,)的张量的原因很简单。此方法具有以下签名:

tf.slice(input_, begin, size, name=None)

重写调用tf.slice(a, [0], [i])的Python方法是a[0:i]

tf.scan的每次迭代中,i内的张量compute是从先前迭代的结果派生的,可以改变并且具有任意值。在编译期间,TensorFlow无法推断切片的确切形状并返回(?,)

在您的情况下,每次迭代i都会增加,并且您会得到各种形状的切片。但是,如果您不更新i并具有相同大小的切片,则输出仍将是(?,),因为i的值来自外部迭代和先前的迭代。

如果您对切片的形状充满信心,那么在运行时就不必铸造切片的形状。例如,根据您的情况,每次迭代的形状都不同。 tf.scan的主要规则是输出和初始化张量具有一致的形状,否则将抛出ValueError