假设我们的张量x
具有未知的第一维(例如[?, 32, 32, 3]
),而另一张量i
实际上是标量。有没有一种很好的方法可以获得i
第一维分割的x
切片,例如,获得尺寸[32, 32, 3]
的张量?我是TensorFlow的新手,只能想出这个极其笨拙的解决方案。
index = tf.concat(0, [i, tf.constant([0, 0, 0], tf.int64)])
size = [1, x.get_shape()[1].value, x.get_shape()[2].value, x.get_shape()[3].value]
result = tf.unpack(tf.slice(x, index, size))[0]
答案 0 :(得分:3)
您可以利用-1
是tf.slice()
size
参数的特殊参数这一事实,这意味着“该维度中的所有剩余元素”。然后,假设i
是标量(而不是像你的代码片段中那样的长度为1的向量),你可以这样做:
result = tf.squeeze(tf.slice(x, tf.pack([index, 0, 0, 0]), [1, -1, -1, -1]), [0])
或者,您可以使用tf.gather()
从第0个维度上的张量中选择一个或多个切片。在这种情况下,i
必须是向量:
i = tf.expand_dims(i, 0) # Converts `i` to a vector if it is a scalar.
result = tf.squeeze(tf.gather(x, i), [0])
在这两种情况下,tf.squeeze()
op都会移除第0维以提供三维结果。