火炬以2D张量收集3D张量作为“聚集图”

时间:2020-09-09 13:49:58

标签: python pytorch

对于给定尺寸的张量
data.shape [B,N,F]indices.shape [N,K]
其中K是每个点N的“邻居”的索引,
是否存在一种简单的方法来收集邻居,使得output.shape [B,N,K,F]吗?

data = [[[1,2,3],[2,3,4],[3,4,5]],[[3,4,5],[6,7,8],[2,3,4]] # Shape 2,3,3
indices = [[1,2],[1,0],[0,0]] # Shape 3,2

output = [
         [[[2,3,4],[3,4,5]],[[2,3,4],[1,2,3]],[[1,2,3],[1,2,3]]],
         [[[6,7,8],[2,3,4]],[[6,7,8],[3,4,5]],[[3,4,5],[3,4,5]]]
         ]

例如,每批中的第一个点“关联”到点[1,2],因此output[0][0] = [[2,3,4],[3,4,5]]是输入的第一批中的点1,2

我的尝试:

batched_indices = indices.unsqueeze(0).unsqueeze(-1).repeat(batch_size,1,1,feature_size)
data_neighs = data.unsqueeze(2).repeat(1,1,num_neighs,1)
output = torch.gather(data_neighs,batched_indices,dim=1)

0 个答案:

没有答案