matplotlib热图与分隔列

时间:2018-01-07 19:55:13

标签: python matplotlib heatmap

我有一个包含多列的热图,如下所示:

Example Heatmap plot

这几乎就是我想要的。我使用imshow来做它,绘图代码非常简单,数据作为2D numpy数组:

plt.imshow(data, cmap="hot", 
               vmin=0.0, vmax=1.0, aspect='auto',  
               interpolation='nearest')
plt.colorbar()
plt.show()

但理想情况下,此头像中的列应该分组(或分解),因为它们代表相关的事物。有没有简单的方法来创建这个图的版本,我可以说,将0-4,5-10和11-22列分隔成单独的块,而不重复颜色条或yaxis标签,但可能有唯一的标签对于每组列?

理想情况下,我想要一个看起来像(在ascii art中)的情节:

0   +---+   +-------+  +-------------+   +-+ 1.0
    |   |   |       |  |             |   | | 
500 |   |   |       |  |             |   | |
    |   |   |       |  |             |   | |
1000+---+   +-------+  +-------------+   +-+ 0.0
     L1      Label2       Label3           

有什么想法吗?

1 个答案:

答案 0 :(得分:2)

您可以轻松切片data数组SameValueNonNumber

之后,只需使用正确的几何体创建轴即可。 using standard numpy array notation为此,或者更简单的You could use Gridspec版本。

data = np.random.random(size=(1000,22))

fig, axs = plt.subplots(1,3,sharey=True,gridspec_kw={'width_ratios':[5,6,12]})
a1 = axs[0].imshow(data[:,:5], cmap="hot", 
               vmin=0.0, vmax=1.0, aspect='auto',  
               interpolation='nearest')
a2 = axs[1].imshow(data[:,5:10], cmap="hot", 
               vmin=0.0, vmax=1.0, aspect='auto',  
               interpolation='nearest')
a3 = axs[2].imshow(data[:,11:], cmap="hot", 
               vmin=0.0, vmax=1.0, aspect='auto',  
               interpolation='nearest')

for ax,l in zip(axs,['Label 1','Label 2','Label 3']):
    ax.set_xticklabels([])
    ax.set_xlabel(l)

plt.colorbar(a1)

plt.show()

plt.subplots()