我正在使用sklearn根据一些零售数据进行kmeans集群。
我们正在幕后使用此集群来细分客户(例如,蓝色客户很棒,绿色客户有这样的需求,等等)。
所显示的数据点是根据客户放入的4个群集中的哪一个而不同的颜色。但是我找不到一种直接推断哪种颜色是哪个段号(或如何强制某些段号成为某种颜色)的方法。
散点图中的 c=y
是使用y值(即观察的预测段)选择颜色的地方。有4个细分市场。我只是不知道这4种颜色中的哪一种被映射为哪种颜色!
有人可以建议我如何添加图例或自己强制颜色吗?
kmeans=kmeans.fit(X)
y=kmeans.predict(X)
fig = plt.figure()
ax = Axes3D(fig)
ax.view_init(30)
ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=y)
ax.set_xlabel(vx)
ax.set_ylabel(vy)
ax.set_zlabel(vz)
答案 0 :(得分:1)
您需要从散点图返回一个句柄并绘制颜色栏,
cm = ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=y)
plt.colorbar(cm)
作为最小示例
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
X = np.random.randn(100,3)
y = np.random.randn(100)
fig = plt.figure()
ax = Axes3D(fig)
ax.view_init(30)
cm = ax.scatter(X[:, 0], X[:, 1], X[:, 2], c=y)
plt.colorbar(cm)
plt.show()