Python:确定数据点属于哪个群集

时间:2019-08-21 03:56:11

标签: python cluster-analysis

很抱歉,虽然这有点重复,但是我尝试了How do I find which cluster my data belongs to using Python?中提供的解决方案,但是对我来说不起作用。

我有一个数据框indicators_df,其中第一列是索引,其他三列是浮点数:

# Sample data    
indicators_df.head(4)

index     1      2       3
  a    52.645  36.167  18.762
  b    34.536  26.772  28.438
  c    46.376  21.784  36.884
  d    33.687  24.979  27.349

在进行K-Means聚类处理之前,我必须对其进行缩放以得到下图:

from sklearn.preprocessing import MinMaxScaler, StandardScaler, RobustScaler

indicators_df

# Define each scaler
robust_scaler = RobustScaler()

# Scale data
X_train_robust = robust_scaler.fit_transform(indicators_df)

# Create an instance of kmeans model
kmeans = KMeans(n_clusters = 3, random_state = 0).fit(X_train_robust)

# Define cluster centers
cluster_centers = kmeans.cluster_centers_
C1 = cluster_centers[:, 0]
C2 = cluster_centers[:, 1]
C3 = cluster_centers[:, 2]

# Define figure
fig = plt.figure()
ax = Axes3D(fig)

# Define x, y, and z axis
x = X_train_robust[:,0]
y = X_train_robust[:,1]
z = X_train_robust[:,2]

# Define axis labels
column_names = indicators_df.columns
ax.set_xlabel(column_names[0])
ax.set_ylabel(column_names[1])
ax.set_zlabel(column_names[2])

# Define markers and colors
ax.scatter(x, y, z, c = kmeans.labels_.astype(float), cmap = 'winter', marker = 'o')
ax.scatter(C1, C2, C3, marker = 'x', color = 'red')

# Define title
plt.title("Visualization of clustered data with {} clusters".format(cluster), fontweight = 'bold')

plt.show()

enter image description here

现在,我想将其链接回indicators_df,以便有一个名为“ ClusterID”的新列,以便indicator_df看起来像这样:

index     1      2       3      ClusterID
  a    52.645  36.167  18.762      1
  b    34.536  26.772  28.438      2
  c    46.376  21.784  36.884      3
  d    33.687  24.979  27.349      2

我尝试了X_train_robust["cluster"] = kmeans,但这没用。

我需要做什么?

谢谢。

1 个答案:

答案 0 :(得分:0)

标签存储在kmeans.labels_字段中。

请勿使用整个对象kmeans。这些细节在编程中很重要。您链接的示例使用fit_predict而不是fit-相似,但是返回labels_