有效选择靠近群集中心的数据点

时间:2018-12-05 03:36:03

标签: python-3.x seaborn data-science

假设我有一个像这样的数据集:

import numpy as np
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
import seaborn as sns

%matplotlib inline

X,y = make_blobs(random_state=101) # My data

palette = sns.color_palette('bright',3)
sns.scatterplot(X[:,0], X[:,1],palette=palette,hue=y) # Visualizing the data

enter image description here

我想选择靠近群集中心的数据。假设我想从cluster '0'中选择靠近中心的数据,我目前正在这样做:

label_0 = X[y==0] # Want to select data from the label '0'

data_index = 2 # Manaully pick the point
sns.scatterplot(X[:,0], X[:,1],palette=palette,hue=y)
plt.scatter(label_0[data_index][0],label_0[data_index][1],marker='*')

enter image description here

由于它不靠近中心,因此我更改了索引并选择了另一个。

data_index = 4
sns.scatterplot(X[:,0], X[:,1],palette=palette,hue=y)
plt.scatter(label_0[data_index][0],label_0[data_index][1],marker='*')

现在关闭了。但是我想知道是否有更有效的方法来实现这一目标?对于像这样的小型数据集,它是可管理的,但是如果我的数据集具有数千个点,则我认为此方法不再可行。 enter image description here

1 个答案:

答案 0 :(得分:0)

一种方法是使用K-means algorithm。 这将帮助您找到每个群集的中心。

鉴于您的数据集,步骤如下:

1)查找簇数

num_clusters=len(np.unique(y)) #here 3

2)在您的数据上应用scikit's k-means clustering

from sklearn.cluster import KMeans
kmeans = KMeans(n_clusters=num_clusters, random_state=0).fit(X)

3)找到每个聚类的中心

centers=kmeans.cluster_centers_ # gives the centers of each cluster
# array([[ 0.26542862,  1.85466779],
#        [-9.50316411, -6.52747391],
#        [ 3.64354311,  6.62683956]])

4)由于这些中心可能不是您原始数据的一部分,因此我们需要找到最接近它们的点

from scipy import spatial

def nearest_point(array,query):
    return array[spatial.KDTree(array).query(query)[1]]

nearest_centers=np.array([nearest_point(X,center) for center in centers])
# array([[ 0.19313183,  1.80387958],
#       [-9.12488396, -6.32638926],
#       [ 3.65986315,  6.69035824]])

5)绘制原始数据和中心

sns.scatterplot(X[:,0], X[:,1],palette=palette,hue=y) 
for nc in nearest_centers:
    plt.scatter(nc[0],nc[1],marker='*',color='r')

中心用红叉显示:

The centers are shows by the red crosses