我有这个张量:
tensor([[[1, 2],
[3, 4]],
[[5, 6],
[7, 8]]])
我有这个索引张量:
tensor([0, 1])
而我想要得到的是根据dim 1和索引张量中相应索引的子张量,即:
tensor([[1, 2],
[7, 8]])
尝试使用 torch.gather() 函数和高级索引没有成功,有人可以帮忙吗?
答案 0 :(得分:1)
您隐式地使用了索引张量的每个值的索引。它们恰好与值相同。如果你想遍历第一级,张量的元素,你可以使用 torch.arange
来构造第一级索引。
import torch
from torch import tensor
t = tensor([[[1, 2],
[3, 4]],
[[5, 6],
[7, 8]]])
ix = tensor([0, 1])
ix0 = torch.arange(0, ix.shape.numel())
t[ix0, ix]
# returns:
tensor([[1, 2],
[7, 8]])