现在我有一个张量farthest_idxs的大小(批处理,特征)=(24,32)。我还有一个大小(点,批处理,特征)的张量Nearest_idxs =(1024、24、1023)。对于一个点p和一个样本s(即,大小为1x1023的near_idxs [p,s ,:]),我想在此向量中找到在farthest_idxs [s,:](大小为1x32)中的第一个元素,并返回一个记录结果的矩阵(大小为24x1024)。有什么有效的方法可以实现吗?
这是我的代码,这是一种无效的实现方式。
def nearest_indices(self, relation, farthest_idxs):
'''Generate the nearest indices
return:
[B, N] matrix
'''
device = relation.device
nearest_value, nearest_idxs = torch.topk(relation, k=1023, dim=2, largest=False, sorted=True)
print('nearest idxs', nearest_idxs)
nearest_idxs = nearest_idxs.transpose(0,1) # 1024x24x1023
print('transposed nearest_idxs', nearest_idxs.shape)
N, B, P = nearest_idxs.shape
upsample_idxs = torch.zeros((B, N), dtype=torch.long).to(device)
for n in range(N):
for b in range(B):
for p in range(P):
if nearest_idxs[n, b, p] in farthest_idxs[b,:]:
upsample_idxs[b, n] = nearest_idxs[n, b, p]
break
print(upsample_idxs.shape)