K-Means Clustering From Scratch in Python

Time: 2017-12-25 03:51:29

Tags: python numpy k-means

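I have been trying to implement a simple k-means clustering algorithm from scratch in Python/NumPy. Initially I used a random array of shape [1000, 2] as the "dataset" so that I could plot it easily, and my code seemed to work (it split the points into k parts and placed a centroid at the center of each part). Now I want to test it on a dataset that has actual groups, so I concatenated two smaller arrays. With k = 2 it works as expected. Larger values of k do not work, however: for some reason some clusters appear to contain 0 points, and I get a "Mean of empty slice" warning together with "invalid value encountered in true_division" on line 42 (np.mean() divides by 0 and yields NaN). I checked where the centroids end up: some clusters are not assigned any points, and plotting shows only 1 centroid, at (0.5, 0.5), instead of k centroids. Why are the other centroids not shown, and why do some centroids have no cluster?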
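
For reference, the warning can be reproduced in isolation. This is a minimal sketch, assuming a cluster that received no points is represented the same way as in the code below (an empty array of shape (0, 2)); the name `empty_cluster` is only illustrative.

```python
import warnings
import numpy as np

# Hypothetical stand-in for a cluster that was assigned no points.
empty_cluster = np.empty((0, 2))

with warnings.catch_warnings():
    # Silence the "Mean of empty slice" / "invalid value" RuntimeWarnings
    # so that only the result is shown.
    warnings.simplefilter("ignore", category=RuntimeWarning)
    # Mean over zero rows: every column comes back as NaN.
    centroid = np.mean(empty_cluster, axis=0)

print(centroid)  # [nan nan]
```

A centroid whose coordinates are NaN has no position to draw, which would be consistent with it not appearing in the scatter plot.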

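I have also tried initializing the centroids at random points from the dataset (instead of initializing them completely at random), and it still frequently produces the same error, although it has worked a few times.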
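
The code for that attempt is not shown here, so the following is only a sketch of one way to do it. The helper name `init_centroids_from_data` and its `rng` parameter are hypothetical. It samples row indices without replacement, on the assumption that duplicate starting points should be avoided: when two centroids coincide, `np.argmin` returns the first of the tied indices, so the second centroid ends up with no points.

```python
import numpy as np

def init_centroids_from_data(X, k, rng=None):
    """Pick k distinct rows of X as starting centroids (hypothetical helper)."""
    rng = np.random.default_rng() if rng is None else rng
    # replace=False guarantees k different row indices.
    idx = rng.choice(X.shape[0], size=k, replace=False)
    # Copy so that later in-place centroid updates do not modify X.
    return X[idx].copy()

# Same two-group layout as in the question, with a fixed seed for repeatability.
rng = np.random.default_rng(0)
w1 = rng.uniform(0.8, 1.0, (50, 2))
w2 = rng.uniform(0.0, 0.2, (50, 2))
X = np.concatenate((w1, w2), axis=0)

centroids = init_centroids_from_data(X, 4, rng)
print(centroids.shape)  # (4, 2)
```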

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

w1 = np.random.uniform(0.8,1.0,(50,2))
w2 = np.random.uniform(0.0,0.2,(50,2))
X = np.concatenate((w1,w2),axis=0)

def dist(a,b):
    a = np.reshape(a,(1,np.shape(a)[0],np.shape(a)[1]))
    b = np.reshape(b,(np.shape(b)[0],1,np.shape(b)[1]))

    dist = np.sqrt(np.sum(np.square(a-b), axis=2))

    a = np.reshape(a,(np.shape(a)[1],np.shape(a)[2]))
    b = np.reshape(b,(np.shape(b)[0],np.shape(b)[2]))
    return dist

def k_means(k, X):
    centroids = np.random.rand(k,np.shape(X)[1])

    error = 1
    while error != 0:
        prevcen = centroids

        cluster = np.argmin(dist(X,centroids), axis=0)

        clusters = {}
        points = {}

        for i in range(k):
            point = np.empty((0,np.shape(X)[1]))
            for j in range(np.shape(X)[0]):
                if cluster[j] == i:
                    concatenation = np.reshape(X[j,:],(1,np.shape(X)[1]))
                    point = np.concatenate((point, concatenation), axis=0)
            points = {i: point}
            clusters.update(points)

        for i in range(k):
            points = clusters[i]
            centroids[i,:] = np.mean(points, axis=0)

        error = np.linalg.norm((prevcen - centroids))

    return clusters, centroids

k = 4
clusters, centroids = k_means(k,X)

colors = ["red", "green", "blue", "beige", "yellow", "magenta","purple", "pink", "cyan", "gray"]

plt.scatter(centroids[:,0], centroids[:,1], s=20, c="black")

for i in range(k):
    cluster = clusters[i]
    plt.scatter(cluster[:,0], cluster[:,1], s=5, c=colors[i])

plt.show()
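
Two details of the loop above look relevant to the symptoms, and the sketch below shows one possible way to handle them. It is not a definitive fix. First, `prevcen = centroids` binds a second name to the same array, so the in-place updates to `centroids` also change `prevcen`; the sketch takes a real copy instead. Second, a cluster with no points makes `np.mean` return NaN; the sketch assumes that keeping the previous centroid in that case is acceptable (re-seeding it at a random data point would be another option). The names `pairwise_dist`, `k_means_sketch`, `max_iter` and `seed` are hypothetical and not part of the original code.

```python
import numpy as np

def pairwise_dist(X, centroids):
    # Broadcasting (n, 1, d) against (1, k, d) gives an (n, k) distance matrix.
    diff = X[:, None, :] - centroids[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=2))

def k_means_sketch(k, X, max_iter=100, seed=0):
    rng = np.random.default_rng(seed)
    # Start from k distinct data points (an assumption, see lead-in).
    centroids = X[rng.choice(X.shape[0], size=k, replace=False)].copy()

    for _ in range(max_iter):
        prev = centroids.copy()  # independent copy, not an alias
        labels = np.argmin(pairwise_dist(X, centroids), axis=1)

        for i in range(k):
            members = X[labels == i]
            if len(members) > 0:
                centroids[i] = members.mean(axis=0)
            # Empty cluster: keep the previous centroid instead of writing NaN.

        if np.allclose(prev, centroids):
            break

    # Clusters follow the last assignment step; at convergence the
    # centroids did not move afterwards, so the two are consistent.
    clusters = {i: X[labels == i] for i in range(k)}
    return clusters, centroids

# Same two-group layout as in the question, with a fixed seed.
rng = np.random.default_rng(42)
w1 = rng.uniform(0.8, 1.0, (50, 2))
w2 = rng.uniform(0.0, 0.2, (50, 2))
X = np.concatenate((w1, w2), axis=0)

clusters, centroids = k_means_sketch(4, X)
print(np.isnan(centroids).any())  # False
```

The returned `clusters` dictionary and `centroids` array have the same layout as in the original `k_means`, so the plotting code above should work with them unchanged.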

0 answers:

No answers