如何在pytorch中的维度上选择单个索引?

时间:2018-12-31 10:23:44

标签: pytorch tensor

假设我有一个形状为sequences的张量[8, 12, 2]。现在,我想为每个第一个维度选择一个张量,以产生形状为[8, 2]的张量。尺寸1的选择由存储在形状为indices的长张量[8]中的索引指定。

我尝试了此操作,但是它为indices中的每个第一维选择了sequences中的每个索引,而不是仅选择一个。

sequences[:, indices]

如何在没有缓慢且丑陋的for循环的情况下进行此查询?

3 个答案:

答案 0 :(得分:1)

您似乎正在寻找torch.gather

torh.gather(sequences, dim=1, index=indices)

答案 1 :(得分:0)

torch.index_select比torch.gather更容易解决您的问题,因为您不必调整indeces的尺寸。指数必须是张量。对于你的情况

indeces = [0,2]
a = torch.rand(size=(3,3,3))
torch.index_select(a,dim=1,index=torch.tensor(indeces,dtype=torch.long))

答案 2 :(得分:0)

torch.gather应该可行,但是您需要先将索引张量转换为

  • unsqueeze,以匹配输入张量的维数
  • repeat_interleave,以匹配上一个尺寸的大小

以下是根据您的描述的示例:

# original indices dimension [8]
# after first unsueeze, dimension is [8, 1]
indices = torch.unsqueeze(indices, 1)

# after second unsueeze, dimension is [8, 1, 1]
indices = torch.unsqueeze(indices, 2)

# after repeat, dimension is [8, 1, 2]
indices = torch.repeat_interleave(indices, 2, dim=2)

# now you have the right dimension for torch.gather
# don't forget to squeeze the redundant dimension
# result has dimension [8, 2] 
result = torch.gather(sequences, 1, indices).squeeze()