沿特定维度和特定通道索引整个张量

时间:2021-04-23 13:59:48

标签: python numpy tensorflow indexing pytorch

假设我们有一个张量 A,其维度为 dim(A)=[i, j, k=6, u, v]。现在我们有兴趣使用 channels=[0:3] 获得 k 维的整个张量。我知道我们可以通过这种方式得到它:

B = A[:, :, 0:3, :, :]

现在我想知道是否有更好的“pythonic”方法可以在不进行这种次优索引的情况下实现相同的结果。我的意思是类似的东西。

B = subset(A, dim=2, index=[0, 1, 2])

无论在哪个框架中,即pytorch、tensorflow、numpy等

非常感谢

1 个答案:

答案 0 :(得分:1)

在 numpy 中,您可以使用 take 方法:

B = A.take([0,1,2], axis=2)

在 TensorFlow 中,没有比使用传统方法更简洁的方法了。使用 tf.slice 会非常冗长:

B = tf.slice(A,[0,0,0,0,0],[-1,-1,3,-1,-1])

您可以潜在地使用 take 的实验版本(自 TF 2.4 起):

B = tf.experimental.numpy.take(A, [0,1,2], axis=2)

在 PyTorch 中,您可以使用 index_select

torch.index_select(A, dim=2, index=torch.tensor([0,1,2]))

请注意,您可以使用 ellipsis 跳过明确列出第一个(或最后一个)维度:

# Both are equivalent in that case
B = A[..., 0:3, :, :]
B = A[:, :, 0:3, ...]
相关问题