试图了解这个Python函数中发生了什么

时间:2017-11-29 02:29:58

标签: python numpy k-means euclidean-distance

def closest_centroid(points, centroids):
    """returns an array containing the index to the nearest centroid for each point"""
    distances = np.sqrt(((points - centroids[:, np.newaxis])**2).sum(axis=2))
    return np.argmin(distances, axis=0)

有人可以解释这个功能的确切工作吗?我目前看到的是points

31998888119     0.94     34
23423423422     0.45     43
....

等等。在此numpy数组中,points[1]为长ID,而points[2]0.94points[3]为第一个条目34。< / p>

Centroids只是这个特定阵列的随机选择:

def initialize_centroids(points, k):
    """returns k centroids from the initial points"""
    centroids = points.copy()
    np.random.shuffle(centroids)
    return centroids[:k] 

现在我希望从points的值中忽略第一列ID和centroids(再次忽略第一列)获得欧几里德距离。我并不完全理解distances = np.sqrt(((points - centroids[:, np.newaxis])**2).sum(axis=2))行的语法。我们为什么要在第三列中进行求和,同时对新轴进行求解:np.newaxis?我还应该沿着什么轴使np.argmin工作?

1 个答案:

答案 0 :(得分:0)

有助于考虑尺寸。我们假设k=4并且有10个点,所以points.shape = (10,3)

接下来,centroids = initialize_centroids(points, 4)会返回尺寸为(4,3)的对象。

让我们从内部分解这一行:

distances = np.sqrt(((points - centroids[:, np.newaxis])**2).sum(axis=2))

  1. 我们想从每个点减去每个质心。由于pointscentroids是2维的,因此每个points - centroid都是2维的。如果只有1个质心,那么我们没问题。但是我们有4个质心!所以我们需要为每个质心执行points - centroids。因此,我们需要另一个维度来存储它。因此添加了np.newaxis

  2. 我们对它进行调整,因为它是一个距离,因此我们希望将负数转换为正数(也因为我们正在最小化欧几里德距离)。

  3. 我们没有对第三栏进行总结。实际上,我们总结了每个质心点和质心之间的差异。

  4. np.argmin()找到距离最小的质心。因此,对于每个质心,对于每个点,找到最小索引(因此argmin而不是min)。该指数是指定给该点的质心。

  5. 以下是一个例子:

    points = np.array([
    [   1, 2, 4],
    [   1, 1, 3],
    [   1, 6, 2],
    [   6, 2, 3],
    [   7, 2, 3],
    [   1, 9, 6],
    [   6, 9, 1],
    [   3, 8, 6],
    [   10, 9, 6],
    [   0, 2, 0],
    ])
    
    centroids = initialize_centroids(points, 4)
    
    print(centroids)
    array([[10,  9,  6],
       [ 3,  8,  6],
       [ 6,  2,  3],
       [ 1,  1,  3]])
    
    distances = (pts - centroids[:, np.newaxis])**2
    
    print(distances)
    array([[[ 81,  49,   4],
        [ 81,  64,   9],
        [ 81,   9,  16],
        [ 16,  49,   9],
        [  9,  49,   9],
        [ 81,   0,   0],
        [ 16,   0,  25],
        [ 49,   1,   0],
        [  0,   0,   0],
        [100,  49,  36]],
    
       [[  4,  36,   4],
        [  4,  49,   9],
        [  4,   4,  16],
        [  9,  36,   9],
        [ 16,  36,   9],
        [  4,   1,   0],
        [  9,   1,  25],
        [  0,   0,   0],
        [ 49,   1,   0],
        [  9,  36,  36]],
    
       [[ 25,   0,   1],
        [ 25,   1,   0],
        [ 25,  16,   1],
        [  0,   0,   0],
        [  1,   0,   0],
        [ 25,  49,   9],
        [  0,  49,   4],
        [  9,  36,   9],
        [ 16,  49,   9],
        [ 36,   0,   9]],
    
       [[  0,   1,   1],
        [  0,   0,   0],
        [  0,  25,   1],
        [ 25,   1,   0],
        [ 36,   1,   0],
        [  0,  64,   9],
        [ 25,  64,   4],
        [  4,  49,   9],
        [ 81,  64,   9],
        [  1,   1,   9]]])
    
    print(distances.sum(axis=2))
    array([[134, 154, 106,  74,  67,  81,  41,  50,   0, 185],
       [ 44,  62,  24,  54,  61,   5,  35,   0,  50,  81],
       [ 26,  26,  42,   0,   1,  83,  53,  54,  74,  45],
       [  2,   0,  26,  26,  37,  73,  93,  62, 154,  11]])
    
    # The minimum of the first 4 centroids is index 3. The minimum of the second 4 centroids is index 3 again.
    
    print(np.argmin(distances.sum(axis=2), axis=0))
    array([3, 3, 1, 2, 2, 1, 1, 1, 0, 3])