在Pytorch中计算欧几里得范数。.难以理解实现

时间:2018-08-23 13:19:07

标签: python pytorch euclidean-distance

我已经看到了另一个StackOverflow线程,它在讨论用于计算欧几里得范数的各种实现,而我很难理解为什么/如何实现特定的实现。

可在MMD指标的实现中找到该代码:https://github.com/josipd/torch-two-sample/blob/master/torch_two_sample/statistics_diff.py

这是一些开始的样板:

import torch
sample_1, sample_2 = torch.ones((10,2)), torch.zeros((10,2))

然后下一部分是我们从上面的代码中提取的内容。.我不确定为什么将这些样本连接在一起。

sample_12 = torch.cat((sample_1, sample_2), 0)
distances = pdist(sample_12, sample_12, norm=2)

,然后传递给pdist函数:

def pdist(sample_1, sample_2, norm=2, eps=1e-5):
    r"""Compute the matrix of all squared pairwise distances.
    Arguments
    ---------
    sample_1 : torch.Tensor or Variable
        The first sample, should be of shape ``(n_1, d)``.
    sample_2 : torch.Tensor or Variable
        The second sample, should be of shape ``(n_2, d)``.
    norm : float
        The l_p norm to be used.
    Returns
    -------
    torch.Tensor or Variable
        Matrix of shape (n_1, n_2). The [i, j]-th entry is equal to
        ``|| sample_1[i, :] - sample_2[j, :] ||_p``."""

这是计算的重点

    n_1, n_2 = sample_1.size(0), sample_2.size(0)
    norm = float(norm)
    if norm == 2.:
        norms_1 = torch.sum(sample_1**2, dim=1, keepdim=True)
        norms_2 = torch.sum(sample_2**2, dim=1, keepdim=True)
        norms = (norms_1.expand(n_1, n_2) +
             norms_2.transpose(0, 1).expand(n_1, n_2))
        distances_squared = norms - 2 * sample_1.mm(sample_2.t())
        return torch.sqrt(eps + torch.abs(distances_squared))

为什么要以这种方式计算欧几里得规范,我感到茫然。任何见识将不胜感激

1 个答案:

答案 0 :(得分:4)

让我们逐步遍历此代码块。欧几里得距离的定义,即L2范数是

enter image description here

让我们考虑最简单的情况。我们有两个示例

enter image description here

样本var y0= d3.scaleBand().paddingInner(0.1).range([0, canvasHeight]);有两个向量a[a00, a01]。样本[a10, a11]相同。首先计算b

norm

现在我们得到

enter image description here

接下来,我们有n1, n2 = a.size(0), b.size(0) # here both n1 and n2 have the value 2 norm1 = torch.sum(a**2, dim=1) norm2 = torch.sum(b**2, dim=1) norms_1.expand(n_1, n_2)

enter image description here

请注意,norms_2.transpose(0, 1).expand(n_1, n_2)已换位。两者之和为b

enter image description here

norm,这是两个矩阵的乘积。

enter image description here

因此,手术后

sample_1.mm(sample_2.t())

你得到

enter image description here

最后,最后一步是获取矩阵中每个元素的平方根。