Python中的简单k均值算法

时间:2018-11-27 21:22:12

标签: python k-means

以下是k-means算法的非常简单的实现。

import numpy as np
import matplotlib.pyplot as plt

np.random.seed(0)

DIM = 2
N = 2000
num_cluster = 4
iterations = 3

x = np.random.randn(N, DIM)
y = np.random.randint(0, num_cluster, N)

mean = np.zeros((num_cluster, DIM))
for t in range(iterations):
    for k in range(num_cluster):
        mean[k] = np.mean(x[y==k], axis=0)
    for i in range(N):
        dist = np.sum((mean - x[i])**2, axis=1)
        pred = np.argmin(dist)
        y[i] = pred

for k in range(num_cluster):
    plt.scatter(x[y==k,0], x[y==k,1])
plt.show()

以下是代码产生的两个示例输出:

enter image description here

enter image description here

第一个示例(num_cluster = 4)看起来像预期的那样。第二个示例(num_cluster = 11)仅在群集上显示,这显然不是我想要的。该代码的工作取决于我定义的类数和迭代数。

到目前为止,我在代码中找不到错误。簇以某种方式消失了,但我不知道为什么。

有人看到我的错误吗?

3 个答案:

答案 0 :(得分:2)

您得到一个集群,因为实际上只有一个集群。
您的代码中没有什么可以避免集群消失,而事实是,这也会在4个集群上发生,但需要更多的迭代。
我用4个集群和1000个迭代运行了您的代码,它们全部被一个大的主导集群吞没了。
想一想,您的大型集群已经通过了一个关键点,并且一直在增长,因为其他点正逐渐变得比以前更接近于它。
如果您达到平衡点(或静止点),而群集之间没有任何移动,则不会发生这种情况。但这显然有点罕见,而且您尝试估计的群集越多,这种情况就越少。


一个澄清:当有4个“真实”群集并且您试图估计4个群集时,同样的事情也可能发生。但这意味着相当讨厌的初始化,可以通过智能地聚合多个随机种子运行来避免。
也有一些常见的“技巧”,例如将初始方法分开,或在不同的预先估计的高密度位置的中心等。但是,这已经开始涉及,您应该更深入地了解有关k均值的知识,这个目的。

答案 1 :(得分:1)

似乎有NaN进入了图片。 使用种子= 1,迭代= 2,簇的数量从最初的4个减少到有效的3个。在下一次迭代中,从技术上讲,簇的数量骤降至1个。

有问题的质心的NaN均值坐标将导致产生奇怪的结果。为了排除那些变成空的有问题的集群,一个选项(可能有点太懒了)是将相关坐标设置为Inf,从而使其比游戏中仍然存在的坐标“更远”。 “输入”坐标不能为Inf)。 下面的代码片段是对此的快速说明,以及一些我用来窥视正在发生的事情的调试消息:

[...]
for k in range(num_cluster):
    mean[k] = np.mean(x[y==k], axis=0)
    # print mean[k]
    if any(np.isnan(mean[k])):
        # print "oh no!"
        mean[k] = [np.Inf] * DIM
[...]

通过这种修改,发布的算法似乎可以更稳定的方式工作(即,到目前为止我还不能打破它)。

另请参阅关于分裂意见的评论中也提到的Quora link,以及example here的《统计学习的要素》一书-该算法也没有在其中明确定义相关方面。

答案 2 :(得分:1)

K-均值对初始条件也很敏感。就是说,k均值可以并且将会丢弃聚类(但是降至一个很奇怪)。在您的代码中,您将随机聚类分配给这些点。

这是问题所在:如果我对您的数据进行几个随机子采样,那么它们的均值将大致相同。每次迭代,非常相似的质心会彼此靠近,并且更有可能掉落。

相反,我更改了代码,以选择数据集中的num_cluster个点用作初始质心(方差较大)。这似乎会产生更稳定的结果(在数十次运行中未观察到集群行为下降):

import numpy as np
import matplotlib.pyplot as plt

DIM = 2
N = 2000
num_cluster = 11
iterations = 3

x = np.random.randn(N, DIM)
y = np.zeros(N)
# initialize clusters by picking num_cluster random points
# could improve on this by deliberately choosing most different points
for t in range(iterations):
    if t == 0:
        index_ = np.random.choice(range(N),num_cluster,replace=False)
        mean = x[index_]
    else:
        for k in range(num_cluster):
            mean[k] = np.mean(x[y==k], axis=0)
    for i in range(N):
        dist = np.sum((mean - x[i])**2, axis=1)
        pred = np.argmin(dist)
        y[i] = pred

for k in range(num_cluster):
    fig = plt.scatter(x[y==k,0], x[y==k,1])
plt.show()

Ran a few dozen times with consistent results