获取特定维度的变量索引的值

时间:2020-03-09 07:38:10

标签: python numpy pytorch

如果我的张量为

value = torch.tensor([

    [[0, 0, 0], [1, 1, 1]],
    [[2, 2, 2], [3, 3, 3]],
])

基本上形状为(2,2,3)

现在说我是否有index = [1, 0],这意味着我想参加:

# row 1 of [[0, 0, 0], [1, 1, 1]], giving me: [1, 1, 1]
# row 0 of [[2, 2, 2], [3, 3, 3]], giving me: [2, 2, 2]

这样最终输出:

output = torch.tensor([[1, 1, 1], [2, 2, 2]])

是否有矢量化的方法来实现这一目标?

1 个答案:

答案 0 :(得分:0)

您可以使用高级索引。
我找不到关于此的好的pytorch文档,但我相信它的作用与numpy相同,因此这里是numpy的document about indexing

import torch

value = torch.tensor([

    [[0, 0, 0], [1, 1, 1]],
    [[2, 2, 2], [3, 3, 3]],
])

index = [1, 0]
i = range(0,2)

result = value[i, index]
# same as result = value[i, index, :] 

print(result)