我有两个数据集(data.csv,label.csv)。“ label.csv”数据集包含“ data.csv”数据集上样本的基因名称。我想在这些数据集中运行k-means聚类算法并绘制散点图。但是我遇到了困难。下面,我给出了为完成聚类而执行的代码。
import pandas as pd
from sklearn.cluster import KMeans
X = pd.read_csv("data.csv")
Y = pd.read_csv("labels.csv")
#reduce the dimension of X
import time
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
X = X.dropna()
# drop the first column which only contains strings
X = X.drop(X.columns[X.columns.str.contains('unnamed', case=False)], axis=1)
# label encode the multiple class string into integer values
Y = Y.drop(Y.columns[0], axis=1)
le = LabelEncoder()
le.fit(Y)
class_names = list(le.classes_)
Y = Y.apply(LabelEncoder().fit_transform)
Y_data = Y.values.flatten()
# use TSNE to visualize the high dimension data in 2D
t0 = time.time()
tsne = TSNE(n_components=2, verbose=1, perplexity=40, n_iter=300, random_state=100)
tsne_results = tsne.fit_transform(X)
t1 = time.time()
print("TSNE took at %.2f seconds" % (t1 - t0))
# visualize TSNE
x_axis = tsne_results[:,0]
y_axis = tsne_results[:,1]
plt.scatter(x_axis, y_axis, c=Y_data, cmap=plt.cm.get_cmap("jet", 100))
plt.colorbar(ticks=range(10))
plt.clim(-0.5, 9.5)
plt.title("TSNE Visualization")
plt.show()
上面的代码给出了我的数据集中5个不同类的散点图(如下所示)。这些类别的颜色从0到4(5个类别)不同。
但是当我应用K-means聚类代码(如下所示)时,它显示5个聚类,但类的颜色不是从0 0到4而是从0到9。散布图图像在k均值聚类下面给出代码。
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
model=KMeans(n_clusters=5)
model.fit(tsne_results)
label=model.predict(tsne_results)
#centroid calculation
xs =tsne_results[:,0]
ys =tsne_results[:,1]
plt.scatter(xs,ys,c=Y_data,alpha=0.5)
centroids = model.cluster_centers_
centroids_x = centroids[:,0]
centroids_y = centroids[:,1]
plt.scatter(centroids_x, centroids_y, marker='D', s=50)
plt.colorbar(ticks=range(10))
plt.clim(-0.5, 9.5)
plt.show()
现在,我需要在代码的哪一部分进行更改,以便获得散点图,例如第一个散点图,其中显示了5个簇从0到4的彩色类。
答案 0 :(得分:-1)
我看不出这里是什么问题。您有5个不同的簇以不同的颜色进行着色,我认为这只是0-9范围内的着色,质心是它们自己的颜色。因此,除非出于某些原因特别需要它们在0-4范围内,否则我将使用此图。您可以看到集群和质心,而这正是您真正需要的。
虽然要看的代码是
plt.colorbar(ticks=range(10))
plt.clim(-0.5, 9.5)