如何在matplotlip中注释热图的子图

时间:2019-12-23 22:49:14

标签: python matplotlib

我正在使用自定义定义的函数在matplotlib中绘制热图,并想添加注释 网格中的每个框。基本上每个网格应显示3,4的注释。我认为解决方案并不简单,因此在询问这里之前,我已经进行了一些研究

按照这些帖子中的说明进行操作 https://github.com/matplotlib/matplotlib/issues/10956 how to annotate heatmap with text in matplotlib? https://matplotlib.org/3.1.1/gallery/images_contours_and_fields/image_annotated_heatmap.html

我试图解决问题,但似乎我遗漏了一些东西。

然后我在下面的facet函数下添加了一个for循环,但它没有输出所需的图并抛出KeyError:(0,0)

import pandas as pd
import numpy as np
import itertools
import seaborn as sns
from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt

def expandgrid(*itrs):
   product = list(itertools.product(*itrs))
   return {'Var{}'.format(i+1):[x[i] for x in product] for i in range(len(itrs))}

methods=["m" + str(i) for i in range(1,3)]
labels=["l" + str(i) for i in range(1,4)]

times = range(0,100,25)
data = pd.DataFrame(expandgrid(methods,labels, times, times))
data.columns = ['method','labels','dtsi','rtsi']
#data['nw_score'] = np.random.sample(data.shape[0])
data['nw_score'] = np.random.choice([3,4],data.shape[0])

data.head()

method labels  dtsi  rtsi  nw_score
0     m1     l1     0     0         4
1     m1     l1     0    25         4
2     m1     l1     0    50         4
3     m1     l1     0    75         3
4     m1     l1    25     0         3

cmap=ListedColormap(['red', 'blue'])

def facet(data, ax):
    data = data.pivot(index="dtsi", columns='rtsi', values='nw_score')
    ax.imshow(data, cmap=cmap,extent=[0,100,0,100])
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            ax.text(j, i, "{:.2f}".format(data[i,j]), ha="center", va="center")



def myfacetgrid(data, row, col, figure=None):
    rows = np.unique(data[row].values)  
    cols = np.unique(data[col].values)

    fig, axs = plt.subplots(3, 2, 
                            figsize=(2*len(cols)+1, 2*len(rows)+1))

    for i, r in enumerate(rows):
        row_data = data[data[row] == r]
        for j, c in enumerate(cols):
            this_data = row_data[row_data[col] == c]
            facet(this_data, axs[i,j])
    return fig, axs

with sns.plotting_context(font_scale=5.5):
    fig, axs = myfacetgrid(data, row="labels", col="method")


    for ax,method in zip(axs[0,:],data.method.unique()):
        ax.set_title(method, fontweight='bold', fontsize=12)
    for ax,label in zip(axs[:,0],data.labels.unique()):
        ax.set_ylabel(label, fontweight='bold', fontsize=12, rotation=0, ha='right', va='center')
        #fig.suptitle(lt, fontweight='bold', fontsize=12)
    fig.tight_layout()
    fig.subplots_adjust(top=0.8) # make some room for the title

enter image description here

facet函数中添加for循环之后,一切似乎都崩溃了。

1 个答案:

答案 0 :(得分:1)

我尝试对facet函数进行以下更改:

  • data.pivot的结果转换为np.array,以便其可寻址为data[i,j]
  • extent=[0,100,0,100]排除在ax.imshow之外,因为该范围会更改坐标系并阻止使用ij来定位文本。

这似乎可行。

def facet(data, ax):
    data = np.array(data.pivot(index="dtsi", columns='rtsi', values='nw_score'))
    ax.imshow(data, cmap=cmap)
    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            ax.text(j, i, "{:.2f}".format(data[i,j]), ha="center", va="center")