假设我有一个形状为sequences
的张量[8, 12, 2]
。现在,我想为每个第一个维度选择一个张量,以产生形状为[8, 2]
的张量。尺寸1的选择由存储在形状为indices
的长张量[8]
中的索引指定。
我尝试了此操作,但是它为indices
中的每个第一维选择了sequences
中的每个索引,而不是仅选择一个。
sequences[:, indices]
如何在没有缓慢且丑陋的for
循环的情况下进行此查询?
答案 0 :(得分:1)
您似乎正在寻找torch.gather
:
torh.gather(sequences, dim=1, index=indices)
答案 1 :(得分:0)
torch.index_select比torch.gather更容易解决您的问题,因为您不必调整indeces的尺寸。指数必须是张量。对于你的情况
indeces = [0,2]
a = torch.rand(size=(3,3,3))
torch.index_select(a,dim=1,index=torch.tensor(indeces,dtype=torch.long))
答案 2 :(得分:0)
这torch.gather
应该可行,但是您需要先将索引张量转换为
unsqueeze
,以匹配输入张量的维数repeat_interleave
,以匹配上一个尺寸的大小以下是根据您的描述的示例:
# original indices dimension [8]
# after first unsueeze, dimension is [8, 1]
indices = torch.unsqueeze(indices, 1)
# after second unsueeze, dimension is [8, 1, 1]
indices = torch.unsqueeze(indices, 2)
# after repeat, dimension is [8, 1, 2]
indices = torch.repeat_interleave(indices, 2, dim=2)
# now you have the right dimension for torch.gather
# don't forget to squeeze the redundant dimension
# result has dimension [8, 2]
result = torch.gather(sequences, 1, indices).squeeze()