我正在尝试获取高斯混合模型中所有群集颜色的图例。数据集很大;我可以在没有适当图例的情况下快速绘制它(图例中仅提供一种颜色):
from sklearn.mixture import GaussianMixture
gmm = GaussianMixture(n_components=9,random_state=15).fit(X)
clust_labels = gmm.predict(X)
plt.figure(figsize=(15,12))
plt.scatter(X.values[:, 0], X.values[:, 1], c=clust_labels, s=40, cmap='viridis')
plt.xlabel('Arrival Time')
plt.xticks(np.arange(0,24,2))
plt.ylabel('Session Duration in Hours')
plt.legend(clust_labels,loc=2)
plt.show()
...但是,如果我尝试将图例中的所有颜色包括在其他一些问题中概述的内容中,则运行循环将花费很长时间。这是我的代码:
from sklearn.mixture import GaussianMixture
gmm = GaussianMixture(n_components=9,random_state=15).fit(X)
clust_labels = gmm.predict(X)
plt.figure(figsize=(15,12))
cmap = plt.get_cmap('viridis')
names = np.unique(clust_labels).tolist()
colors = cmap(np.linspace(0, 1, len(names)))
for label in clust_labels:
x = [X.values[i,0] for i in range(len(clust_labels)) if clust_labels[i]==label]
y = [X.values[i,1] for i in range(len(clust_labels)) if clust_labels[i]==label]
plt.scatter(x, y, color=colors[label], label=clust_labels[label])
plt.xlabel('Arrival Time')
plt.xticks(np.arange(0,24,2))
plt.ylabel('Session Duration in Hours')
plt.legend(loc=2)
plt.show()
似乎必须有一种更简单的方法来完成此操作,而不会导致永恒循环。在图的外部有一张表格将颜色映射到群集,这对我来说甚至可以接受;我只需要一种引用聚类的方式即可在图外进行进一步的工作(即仅处理来自某些聚类的数据)。