对于给定尺寸的张量
data.shape [B,N,F]
和indices.shape [N,K]
,
其中K
是每个点N的“邻居”的索引,
是否存在一种简单的方法来收集邻居,使得output.shape [B,N,K,F]
吗?
data = [[[1,2,3],[2,3,4],[3,4,5]],[[3,4,5],[6,7,8],[2,3,4]] # Shape 2,3,3
indices = [[1,2],[1,0],[0,0]] # Shape 3,2
output = [
[[[2,3,4],[3,4,5]],[[2,3,4],[1,2,3]],[[1,2,3],[1,2,3]]],
[[[6,7,8],[2,3,4]],[[6,7,8],[3,4,5]],[[3,4,5],[3,4,5]]]
]
例如,每批中的第一个点“关联”到点[1,2]
,因此output[0][0] = [[2,3,4],[3,4,5]]
是输入的第一批中的点1,2
。
我的尝试:
batched_indices = indices.unsqueeze(0).unsqueeze(-1).repeat(batch_size,1,1,feature_size)
data_neighs = data.unsqueeze(2).repeat(1,1,num_neighs,1)
output = torch.gather(data_neighs,batched_indices,dim=1)