我刚刚将matplotlib升级到版本3.1.1,并且正在尝试使用legend_elements。
我正在对30,000个展平的灰度图像的数据集绘制来自PCA的前两个成分的散点图。每个图像都被标记为四个主要类别之一(附件,服装,鞋类,个人护理)。我通过创建一个颜色列(值介于0到3之间)来按“主类别”对绘图进行颜色编码。
我已经阅读了PathCollection.legend_elements的文档,但尚未成功合并'func'或'fmt'参数。 https://matplotlib.org/3.1.1/api/collections_api.html#matplotlib.collections.PathCollection.legend_elements
此外,我尝试遵循提供的示例: https://matplotlib.org/3.1.1/gallery/lines_bars_and_markers/scatter_with_legend.html
### create column for color codes
masterCat_codes = {'Accessories':0,'Apparel':1, 'Footwear':2, 'Personal Care':3}
df['colors'] = df['masterCategory'].apply(lambda x: masterCat_codes[x])
### create scatter plot
fig, ax = plt.subplots(figsize=(8,8))
scatter = ax.scatter( *full_pca.T, s=.1 , c=df['colors'], label= df['masterCategory'], cmap='viridis')
### using legend_elements
legend1 = ax.legend(*scatter.legend_elements(num=[0,1,2,3]), loc="upper left", title="Category Codes")
ax.add_artist(legend1)
plt.show()
生成的图例标签为0、1、2、3。(无论在定义“散布”时是否指定label = df ['masterCategory']都会发生)。我想标签说配件,服装,鞋类,个人护理。
有没有一种方法可以通过legend_elements完成?
注意:由于数据集很大且预处理的计算量很大,因此我编写了一个更易于重现的示例:
fake_data = np.array([[1,1],[1,2],[1,3],[2,1],[2,2],[2,3],[3,1],[3,2],[3,3]])
fake_df = pd.DataFrame(fake_data, columns=['X', 'Y'])
groups = np.array(['A', 'A', 'A', 'B', 'B', 'B', 'C', 'C', 'C'])
fake_df['Group'] = groups
group_codes = {k:idx for idx, k in enumerate(fake_df.Group.unique())}
fake_df['colors'] = fake_df['Group'].apply(lambda x: group_codes[x])
fig, ax = plt.subplots()
scatter = ax.scatter(fake_data[:,0], fake_data[:,1], c=fake_df['colors'])
legend = ax.legend(*scatter.legend_elements(num=[0,1,2]), loc="upper left", title="Group \nCodes")
ax.add_artist(legend)
plt.show()
答案 0 :(得分:1)
解决方案 感谢EmportanceOfBeingErnest
group_codes = {k:idx for idx, k in enumerate(fake_df.Group.unique())}
fake_df['colors'] = fake_df['Group'].apply(lambda x: group_codes[x])
fig, ax = plt.subplots(figsize=(8,8))
ax.scatter(fake_data[:,0], fake_data[:,1], c=fake_df['colors'])
ax.legend(handles=scatter.legend_elements(num=[0,1,2,3])[0], labels=group_codes.keys())
plt.show()