matplotlib图例中的子标题,带有已定义的类别子集和堆积图

时间:2017-06-15 22:59:10

标签: python pandas matplotlib plot

pandas数据框具有以下形式,其中Year是索引:

      A:Cat1  A:Cat2  B:Cat1  B:Cat2  B:Cat3
Year                                        
1977     0.5    0.25    0.15     0.1     0.1
1981     0.2     NaN    0.40     0.1     0.2
1983     0.1    0.10    0.30     0.2     0.3

重要的是你在两个不同的“超类别”A和B中有相同的类别Cat1和Cat2。为了绘制所有类别的变化,我使用堆叠图并使用两组不同的颜色对于每个超类别。所有这些颜色都保存在列表colors中。

我现在正在做的是绘制图表(pltpyplot):

plt.stackplot(data.index.values,data.fillna(0).T.values,colors=colors,labels=data.columns.values)
plt.legend(loc="best")

这给出了以下数据:

Result of previous code

现在,我想要做的是避免在图例中重复超类别A和B,方法是为每个超类别创建两个不同的图例,或者在同一图例中包含副标题。我查看了关于副标题的this other question,但重点是我希望能够指定图例的两列之间的断点,因此仅指定ncol=2不起作用,因为它不起作用因为我在每个«supercategory»中没有相同数量的类别,所以在正确的点上中断。

1 个答案:

答案 0 :(得分:1)

也许尝试添加占位符来处理每个超类别中不等数量的类别。或者使用水平组标签:

import io

import pandas as pd
import matplotlib.lines as mlines
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt

s = """1977 0.5 0.25    0.15    0.1 0.1
1981    0.2 NaN 0.40    0.1 0.2
1983    0.1 0.10    0.30    0.2 0.3"""
t = [('A', 'Cat1'), 
     ('A', 'Cat2'), 
     ('B', 'Cat1'), 
     ('B', 'Cat2'), 
     ('B', 'Cat3')]
index = pd.MultiIndex.from_tuples(t)
df = pd.read_table(io.StringIO(s), names=index)
df.index.name = 'Year'
colors = ['b', 'c', 'k', 'g', 'w']
plt.stackplot(df.index.values,df.fillna(0).T.values,colors=colors)

ha = mlines.Line2D([], [], marker='None', linestyle='None')
hb = mlines.Line2D([], [], marker='None', linestyle='None')
ha1 = mpatches.Patch(color=colors[0], ec='k')
ha2 = mpatches.Patch(color=colors[1], ec='k')
hb1 = mpatches.Patch(color=colors[2], ec='k')
hb2 = mpatches.Patch(color=colors[3], ec='k')
hb3 = mpatches.Patch(color=colors[4], ec='k')
hblank = mpatches.Patch(visible=False)
l1 = plt.legend([ha, ha1, ha2, hblank, hb, hb1, hb2, hb3], 
                ['A', 'Cat1', 'Cat2', '', 'B', 'Cat1', 'Cat2', 'Cat3'], 
                loc=2, ncol=2) # Two columns, vertical group labels
l2 = plt.legend([ha, hblank, hb, hblank, hblank, ha1, ha2, hb1, hb2, hb3], 
                ['A', '', 'B', '', '', 'Cat1', 'Cat2', 'Cat1', 'Cat2', 'Cat3'], 
                loc=4, ncol=2) # Two columns, horizontal group labels

ax = plt.gca()
ax.add_artist(l1)
ax.get_xaxis().get_major_formatter().set_useOffset(False)
plt.show()

enter image description here