如果我的张量为
value = torch.tensor([
[[0, 0, 0], [1, 1, 1]],
[[2, 2, 2], [3, 3, 3]],
])
基本上形状为(2,2,3)
。
现在说我是否有index = [1, 0]
,这意味着我想参加:
# row 1 of [[0, 0, 0], [1, 1, 1]], giving me: [1, 1, 1]
# row 0 of [[2, 2, 2], [3, 3, 3]], giving me: [2, 2, 2]
这样最终输出:
output = torch.tensor([[1, 1, 1], [2, 2, 2]])
是否有矢量化的方法来实现这一目标?
答案 0 :(得分:0)
您可以使用高级索引。
我找不到关于此的好的pytorch文档,但我相信它的作用与numpy相同,因此这里是numpy的document about indexing。
import torch
value = torch.tensor([
[[0, 0, 0], [1, 1, 1]],
[[2, 2, 2], [3, 3, 3]],
])
index = [1, 0]
i = range(0,2)
result = value[i, index]
# same as result = value[i, index, :]
print(result)