在TensorFlow中获取由另一个张量部分索引的切片有什么好方法?

时间:2016-03-21 14:01:33

标签: tensorflow

假设我们的张量x具有未知的第一维(例如[?, 32, 32, 3]),而另一张量i实际上是标量。有没有一种很好的方法可以获得i第一维分割的x切片,例如,获得尺寸[32, 32, 3]的张量?我是TensorFlow的新手,只能想出这个极其笨拙的解决方案。

index = tf.concat(0, [i, tf.constant([0, 0, 0], tf.int64)])
size = [1, x.get_shape()[1].value, x.get_shape()[2].value, x.get_shape()[3].value]
result = tf.unpack(tf.slice(x, index, size))[0]

1 个答案:

答案 0 :(得分:3)

您可以利用-1tf.slice() size参数的特殊参数这一事实,这意味着“该维度中的所有剩余元素”。然后,假设i是标量(而不是像你的代码片段中那样的长度为1的向量),你可以这样做:

result = tf.squeeze(tf.slice(x, tf.pack([index, 0, 0, 0]), [1, -1, -1, -1]), [0])

或者,您可以使用tf.gather()从第0个维度上的张量中选择一个或多个切片。在这种情况下,i必须是向量:

i = tf.expand_dims(i, 0)  # Converts `i` to a vector if it is a scalar.
result = tf.squeeze(tf.gather(x, i), [0])

在这两种情况下,tf.squeeze() op都会移除第0维以提供三维结果。