假设我有一个包含3个类的数据,下面的代码可以为我提供一个带有正确图例的完美图形,其中按类绘制数据。
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import make_blobs
import numpy as np
X, y = make_blobs()
X0 = X[y==0]
X1 = X[y==1]
X2 = X[y==2]
ax = plt.subplot(1,1,1)
ax.scatter(X0[:,0],X0[:,1], lw=0, s=40)
ax.scatter(X1[:,0],X1[:,1], lw=0, s=40)
ax.scatter(X2[:,0],X2[:,1], lw=0, s=40)
ax.legend(['0','1','2'])
但是,如果我有3000个类的数据集,则上述方法不再起作用。 (您不会期望我写每个类对应的3000行,对吧?) 因此,我提出了以下绘图代码。
num_classes = len(set(y))
palette = np.array(sns.color_palette("hls", num_classes))
ax = plt.subplot(1,1,1)
ax.scatter(X[:,0], X[:,1], lw=0, s=40, c=palette[y.astype(np.int)])
ax.legend(['0','1','2'])
这段代码很完美,我们只用1行就可以绘制所有类。但是,图例这次没有正确显示。
在使用以下方法绘制图形时,如何保持正确的图例?
ax.scatter(X[:,0], X[:,1], lw=0, s=40, c=palette[y.astype(np.int)])
答案 0 :(得分:2)
plt.legend()
在剧情中有多个“艺术家”时效果最佳。在您的第一个示例中就是这种情况,这就是为什么轻松调用plt.legend(labels)
的原因。
如果您担心编写大量代码行,则可以利用for
循环。
在使用5个类的示例中可以看到:
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
import numpy as np
X, y = make_blobs(centers=5)
ax = plt.subplot(1,1,1)
for c in np.unique(y):
ax.scatter(X[y==c,0],X[y==c,1],label=c)
ax.legend()
np.unique()
返回y的唯一元素的排序数组,方法是循环遍历这些元素,并用自己的艺术家plt.legend()
绘制每个类,可以轻松地提供图例。
您还可以在制作图时为其分配标签,这可能更安全。
plt.scatter(..., label=c)
后跟plt.legend()
答案 1 :(得分:0)
为什么不简单地执行以下操作?
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import make_blobs
import numpy as np
X, y = make_blobs()
ngroups = 3
ax = plt.subplot(1, 1, 1)
for i in range(ngroups):
ax.scatter(X[y==i][:,0], X[y==i][:,1], lw=0, s=40, label=i)
ax.legend()