寻找K-最近邻及其实现

时间:2014-12-15 00:58:08

标签: matlab machine-learning knn

我正在使用具有欧几里德距离的KNN对简单数据进行分类。我已经看到了一个关于我想用MATLAB knnsearch函数完成的示例,如下所示:

load fisheriris 
x = meas(:,3:4);
gscatter(x(:,1),x(:,2),species)
newpoint = [5 1.45];
[n,d] = knnsearch(x,newpoint,'k',10);
line(x(n,1),x(n,2),'color',[.5 .5 .5],'marker','o','linestyle','none','markersize',10)

上面的代码采用了一个新点,即[5 1.45],并找到了与新点最接近的10个值。任何人都可以给我看一个MATLAB算法,详细解释knnsearch函数的作用吗?还有其他办法吗?

1 个答案:

答案 0 :(得分:38)

K-Nearest Neighbor(KNN)算法的基础是您有一个由N行和M列组成的数据矩阵,其中N是数据点的数量我们拥有,而M是每个数据点的维度。例如,如果我们将笛卡尔坐标放在数据矩阵中,则通常是N x 2N x 3矩阵。使用此数据矩阵,您可以提供查询点,并搜索此数据矩阵中距离此查询点最近的最近k个点。

我们通常使用查询与数据矩阵中其余点之间的欧几里德距离来计算距离。但是,也使用其他距离,如L1或City-Block / Manhattan距离。在此操作之后,您将具有N欧几里德或曼哈顿距离,其表示查询与数据集中的每个对应点之间的距离。找到这些后,您只需按升序对距离进行排序,然后检索数据集与查询之间距离最小的k点,即可搜索k最近的查询点。 / p>

假设您的数据矩阵存储在x中,newpoint是一个包含M列(即1 x M)的示例点,这是您的一般程序将遵循以下形式:

  1. 找出newpointx中每个点之间的欧几里德或曼哈顿距离。
  2. 按升序对这些距离进行排序。
  3. 返回距离k最近的x newpoint个数据点。
  4. 让我们慢慢地做每一步。


    第1步

    有人可能会这样做的一种方式可能就是for循环:

    N = size(x,1);
    dists = zeros(N,1);
    for idx = 1 : N
        dists(idx) = sqrt(sum((x(idx,:) - newpoint).^2));
    end
    

    如果你想实现曼哈顿距离,那就是:

    N = size(x,1);
    dists = zeros(N,1);
    for idx = 1 : N
        dists(idx) = sum(abs(x(idx,:) - newpoint));
    end
    

    dists将是N元素向量,其中包含xnewpoint中每个数据点之间的距离。我们在newpointx中的数据点之间进行逐元素减法,将差异平方,然后sum将它们全部放在一起。然后这个总和是平方根,这完成了欧几里德距离。对于曼哈顿距离,您将逐个元素执行元素,获取绝对值,然后将所有组件加在一起。这可能是最简单的实现,但它可能是效率最低的......尤其是对于更大的数据集和更大的数据维度。

    另一种可能的解决方案是复制newpoint并使此矩阵与x的大小相同,然后对此矩阵进行逐个元素的减法,然后对所有列进行求和每一行并做平方根。因此,我们可以这样做:

    N = size(x, 1);
    dists = sqrt(sum((x - repmat(newpoint, N, 1)).^2, 2));
    

    对于曼哈顿距离,你会这样做:

    N = size(x, 1);
    dists = sum(abs(x - repmat(newpoint, N, 1)), 2);
    

    repmat采用矩阵或向量,并在给定方向上重复它们一定次数。在我们的示例中,我们希望使用newpoint向量,并将此N次叠加在一起以创建N x M矩阵,其中每行为M个元素长。我们将这两个矩阵一起减去,然后对每个分量进行平方。完成此操作后,我们sum覆盖每行的所有列,最后获取所有结果的平方根。对于曼哈顿距离,我们进行减法,取绝对值然后求和。

    但是,在我看来,最有效的方法是使用bsxfun。这基本上是通过单个函数调用在我们讨论的复制。因此,代码就是这样:

    dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
    

    对我来说,这看起来更清洁,更重要。对于曼哈顿距离,您可以这样做:

    dists = sum(abs(bsxfun(@minus, x, newpoint)), 2);
    

    步骤#2

    现在我们有了距离,我们只是对它们进行排序。我们可以使用sort对距离进行排序:

    [d,ind] = sort(dists);
    

    d将包含按升序排序的距离,而ind会告诉您未排序数组中的每个值,它出现在已排序结果。我们需要使用ind,提取此向量的第一个k元素,然后使用ind索引到我们的x数据矩阵,以返回最接近的那些点newpoint

    步骤#3

    最后一步是返回最接近k的{​​{1}}个数据点。我们可以通过以下方式做到这一点:

    newpoint

    ind_closest = ind(1:k); x_closest = x(ind_closest,:); 应包含最接近ind_closest的原始数据矩阵x中的索引。具体而言,newpoint包含您需要从ind_closest中抽样的,以获得与x最近的点。 newpoint将包含这些实际数据点。


    为了您的复制和粘贴乐趣,这就是代码的样子:

    x_closest

    通过您的示例,让我们看看我们的代码:

    dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
    %// Or do this for Manhattan
    % dists = sum(abs(bsxfun(@minus, x, newpoint)), 2);
    [d,ind] = sort(dists);
    ind_closest = ind(1:k);
    x_closest = x(ind_closest,:);
    

    通过检查load fisheriris x = meas(:,3:4); newpoint = [5 1.45]; k = 10; %// Use Euclidean dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2)); [d,ind] = sort(dists); ind_closest = ind(1:k); x_closest = x(ind_closest,:); ind_closest,我们得到的是:

    x_closest

    如果您运行>> ind_closest ind_closest = 120 53 73 134 84 77 78 51 64 87 >> x_closest x_closest = 5.0000 1.5000 4.9000 1.5000 4.9000 1.5000 5.1000 1.5000 5.1000 1.6000 4.8000 1.4000 5.0000 1.7000 4.7000 1.4000 4.7000 1.4000 4.7000 1.5000 ,您会发现变量knnsearchn匹配。但是,变量ind_closest会将距离d返回到每个点newpoint,而不是实际数据点本身。如果你想要实际的距离,只需在我写的代码后执行以下操作:

    x

    请注意,上述答案仅使用一批dist_sorted = d(1:k); 个示例中的一个查询点。 KNN非常频繁地同时用于多个示例。假设我们想要在KNN中测试N个查询点。这将产生Q矩阵,对于每个示例或每个切片,我们返回k x M x Q最近点,其维度为k。或者,我们可以返回M最近点的 ID ,从而生成k矩阵。让我们计算两者。

    一种天真的方法是在循环中应用上面的代码并循环遍历每个示例。

    在我们分配Q x k矩阵并应用基于Q x k的方法将输出矩阵的每一行设置为数据集中bsxfun最近点的情况下,这样的事情就可以工作我们将像以前一样使用Fisher Iris数据集。我们还会保留与前一个示例中相同的维度,并且我将使用四个示例,因此kQ = 4

    M = 2

    虽然这很好,但我们可以做得更好。有一种方法可以有效地计算两组矢量之间的平方欧几里德距离。如果你想和曼哈顿一起做这件事,我会把它留作练习。咨询this blog,因为%// Load the data and create the query points load fisheriris; x = meas(:,3:4); newpoints = [5 1.45; 7 2; 4 2.5; 2 3.5]; %// Define k and the output matrices Q = size(newpoints, 1); M = size(x, 2); k = 10; x_closest = zeros(k, M, Q); ind_closest = zeros(Q, k); %// Loop through each point and do logic as seen above: for ii = 1 : Q %// Get the point newpoint = newpoints(ii, :); %// Use Euclidean dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2)); [d,ind] = sort(dists); %// New - Output the IDs of the match as well as the points themselves ind_closest(ii, :) = ind(1 : k).'; x_closest(:, :, ii) = x(ind_closest(ii, :), :); end A矩阵,其中每一行都是维度点Q1 x M,其中M点和Q1是一个B矩阵,其中每一行也是具有Q2 x M点的维度点M,我们可以有效地计算距离矩阵Q2,其中行D(i, j)处的元素列i表示j的行iA的行j之间的距离,使用以下矩阵表达式:

    B

    因此,如果我们让nA = sum(A.^2, 2); %// Sum of squares for each row of A nB = sum(B.^2, 2); %// Sum of squares for each row of B D = bsxfun(@plus, nA, nB.') - 2*A*B.'; %// Compute distance matrix D = sqrt(D); %// Compute square root to complete calculation 成为查询点的矩阵,A是由原始数据组成的数据集,我们可以通过单独排序每一行来确定B最近点并确定每行最小的k位置。我们还可以使用它来自己检索实际点。

    因此:

    k

    我们看到我们使用逻辑来计算距离矩阵是相同的但是一些变量已经改变以适应这个例子。我们还使用%// Load the data and create the query points load fisheriris; x = meas(:,3:4); newpoints = [5 1.45; 7 2; 4 2.5; 2 3.5]; %// Define k and other variables k = 10; Q = size(newpoints, 1); M = size(x, 2); nA = sum(newpoints.^2, 2); %// Sum of squares for each row of A nB = sum(x.^2, 2); %// Sum of squares for each row of B D = bsxfun(@plus, nA, nB.') - 2*newpoints*x.'; %// Compute distance matrix D = sqrt(D); %// Compute square root to complete calculation %// Sort the distances [d, ind] = sort(D, 2); %// Get the indices of the closest distances ind_closest = ind(:, 1:k); %// Also get the nearest points x_closest = permute(reshape(x(ind_closest(:), :).', M, k, []), [2 1 3]); 的两个输入版本独立地对每一行进行排序,因此sort将包含每行的ID,ind将包含相应的距离。然后我们通过简单地将此矩阵截断为d列来确定哪些索引最接近每个查询点。然后,我们使用permutereshape来确定相关的最近点是什么。我们首先使用所有最接近的索引并创建一个点矩阵,将所有ID叠加在一起,这样我们得到一个k矩阵。使用Q * k x Mreshape可以让我们创建3D矩阵,使其成为我们指定的permute矩阵。如果你想自己获得实际距离,我们可以索引到k x M x Q并抓住我们需要的东西。为此,您需要使用sub2ind来获取线性索引,以便我们可以一次性索引到dd的值已经为我们提供了我们需要访问的列。我们需要访问的行只有1 ind_closest次,2 k次,等等,直至kQ代表我们想要返回的点数:

    k

    当我们为上述查询点运行上述代码时,这些是我们得到的索引,点和距离:

    row_indices = repmat((1:Q).', 1, k);
    linear_ind = sub2ind(size(d), row_indices, ind_closest);
    dist_sorted = D(linear_ind);
    

    要与>> ind_closest ind_closest = 120 134 53 73 84 77 78 51 64 87 123 119 118 106 132 108 131 136 126 110 107 62 86 122 71 127 139 115 60 52 99 65 58 94 60 61 80 44 54 72 >> x_closest x_closest(:,:,1) = 5.0000 1.5000 6.7000 2.0000 4.5000 1.7000 3.0000 1.1000 5.1000 1.5000 6.9000 2.3000 4.2000 1.5000 3.6000 1.3000 4.9000 1.5000 6.7000 2.2000 x_closest(:,:,2) = 4.5000 1.6000 3.3000 1.0000 4.9000 1.5000 6.6000 2.1000 4.9000 2.0000 3.3000 1.0000 5.1000 1.6000 6.4000 2.0000 4.8000 1.8000 3.9000 1.4000 x_closest(:,:,3) = 4.8000 1.4000 6.3000 1.8000 4.8000 1.8000 3.5000 1.0000 5.0000 1.7000 6.1000 1.9000 4.8000 1.8000 3.5000 1.0000 4.7000 1.4000 6.1000 2.3000 x_closest(:,:,4) = 5.1000 2.4000 1.6000 0.6000 4.7000 1.4000 6.0000 1.8000 3.9000 1.4000 4.0000 1.3000 4.7000 1.5000 6.1000 2.5000 4.5000 1.5000 4.0000 1.3000 >> dist_sorted dist_sorted = 0.0500 0.1118 0.1118 0.1118 0.1803 0.2062 0.2500 0.3041 0.3041 0.3041 0.3000 0.3162 0.3606 0.4123 0.6000 0.7280 0.9055 0.9487 1.0198 1.0296 0.9434 1.0198 1.0296 1.0296 1.0630 1.0630 1.0630 1.1045 1.1045 1.1180 2.6000 2.7203 2.8178 2.8178 2.8320 2.9155 2.9155 2.9275 2.9732 2.9732 进行比较,您需要为第二个参数指定一个点矩阵,其中每一行都是一个查询点,您将看到此实现与{{1}之间的索引和已排序距离匹配}}


    希望这会对你有所帮助。祝你好运!