如何在PyTorch中计算点集和线之间的成对距离?

时间:2019-11-01 13:28:08

标签: python pytorch torch

点集A是一个Nx3矩阵,从两个BC的两个Mx3大小相同的点集中,我们可以得到第BC行在它们之间。现在,我要计算从A中的每个点到BC中的每条线的距离。 BMx3,而CMx3,则这些线是从具有相应行的点开始的,因此BC是一个Mx3矩阵。基本方法计算如下:

D = torch.zeros((N, M), dtype=torch.float32)
for i in range(N):
    p = A[i]  # 1x3
    for j in range(M):
        p1 = B[j] # 1x3
        p2 = C[j] # 1x3
        D[i,j] = torch.norm(torch.cross(p1 - p2, p - p1)) / torch.norm(p1 - p2) 

有没有更快的方法来完成这项工作?谢谢。

1 个答案:

答案 0 :(得分:3)

您可以通过执行以下操作删除for循环(除非MN很小,否则它应以内存为代价加快速度):

diff_B_C = B - C
diff_A_C = A[:, None] - C
norm_lines = torch.norm(diff_B_C, dim=-1)
cross_result = torch.cross(diff_B_C[None, :].expand(N, -1, -1), diff_A_C, dim=-1)
norm_cross = torch.norm(cross_result, dim=-1)
D = norm_cross / norm_lines

当然,您不需要逐步进行操作。我只是想弄清楚变量名。

注意:如果您未向dim提供torch.cross,它将使用第一个dim=3,如果{{1 }}(来自docs):

  

如果未指定dim,则默认为尺寸为3的第一个尺寸。

如果您想知道,可以检查here为什么选择N=3而不是expand