假设我有一个像这样的数据集:
import numpy as np
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
X,y = make_blobs(random_state=101) # My data
palette = sns.color_palette('bright',3)
sns.scatterplot(X[:,0], X[:,1],palette=palette,hue=y) # Visualizing the data
我想选择靠近群集中心的数据。假设我想从cluster '0'
中选择靠近中心的数据,我目前正在这样做:
label_0 = X[y==0] # Want to select data from the label '0'
data_index = 2 # Manaully pick the point
sns.scatterplot(X[:,0], X[:,1],palette=palette,hue=y)
plt.scatter(label_0[data_index][0],label_0[data_index][1],marker='*')
由于它不靠近中心,因此我更改了索引并选择了另一个。
data_index = 4
sns.scatterplot(X[:,0], X[:,1],palette=palette,hue=y)
plt.scatter(label_0[data_index][0],label_0[data_index][1],marker='*')
现在关闭了。但是我想知道是否有更有效的方法来实现这一目标?对于像这样的小型数据集,它是可管理的,但是如果我的数据集具有数千个点,则我认为此方法不再可行。
答案 0 :(得分:0)
一种方法是使用K-means algorithm。 这将帮助您找到每个群集的中心。
鉴于您的数据集,步骤如下:
1)查找簇数
num_clusters=len(np.unique(y)) #here 3
2)在您的数据上应用scikit's k-means clustering
from sklearn.cluster import KMeans
kmeans = KMeans(n_clusters=num_clusters, random_state=0).fit(X)
3)找到每个聚类的中心
centers=kmeans.cluster_centers_ # gives the centers of each cluster
# array([[ 0.26542862, 1.85466779],
# [-9.50316411, -6.52747391],
# [ 3.64354311, 6.62683956]])
4)由于这些中心可能不是您原始数据的一部分,因此我们需要找到最接近它们的点
from scipy import spatial
def nearest_point(array,query):
return array[spatial.KDTree(array).query(query)[1]]
nearest_centers=np.array([nearest_point(X,center) for center in centers])
# array([[ 0.19313183, 1.80387958],
# [-9.12488396, -6.32638926],
# [ 3.65986315, 6.69035824]])
5)绘制原始数据和中心
sns.scatterplot(X[:,0], X[:,1],palette=palette,hue=y)
for nc in nearest_centers:
plt.scatter(nc[0],nc[1],marker='*',color='r')
中心用红叉显示: