我正在尝试使用不同的降维技术将时尚MNIST数据集可视化,我还想将图例附加到具有所谓的real_labels的结果图片中,该标签可以告诉标签的真实名称。对于时尚MNIST来说,真正的标签是:
real_labels = ['t-shirt','trouser','pullover','dress','coat','sandal','shirt','sneaker','bag','ankle boot']
我正在执行以下功能:
def Draw_datasamples_to_figure(X_scaled, labels, axis):
y = ['${}$'.format(i) for i in labels]
num_cls = len(list(set(labels)))
for (X_plot, Y_plot, y1, label1) in zip(X_scaled[:,0], X_scaled[:,1], y, labels):
axis.scatter(X_plot, Y_plot, color=cm.gnuplot(int(label1)/num_cls),label=y1, marker=y1, s=60)
,其中X_scaled指示x和y坐标,标签是类信息的整数(0-9),而轴则指示将在哪个子图窗口中绘制图片。
使用以下命令绘制图例:
ax3.legend(real_labels, loc='center left', bbox_to_anchor=(1, 0.5))
在将图例绘制成图画之前,一切似乎都运作良好。从下面的图片中可以看到,不是数字从0到9,而是在图例中选择的数字是任意的。
我知道问题很可能是分散的,我应该以另一种方式实现它,但是我希望我仍然想念一些简单的东西来解决我的实现问题。我也不想在代码中定义标记和名称的地方使用手工制作的图例,因为我还有其他具有不同类和真实标签名称的数据集。预先感谢!