我正在尝试为 pandas 条形图添加分层（多级）标签。我想在“How to add group labels for bar charts in matplotlib?”中 @Stein 的解决方案（第二个解决方案）的基础上做一点改动：不是画垂直线来表示分组，而是改用水平线。
垂直线（原始解决方案）
水平线（期望的效果）
@Stein的代码
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
from itertools import groupby
def add_line(ax, xpos, ypos):
    """Draw a short vertical separator on ``ax`` in axes coordinates.

    The segment runs from ``(xpos, ypos + 0.1)`` down to ``(xpos, ypos)``.
    Clipping is turned off so the segment stays visible even when it lies
    outside the axes area (e.g. at a negative ``ypos`` below the plot).
    """
    top = ypos + .1
    separator = plt.Line2D(
        [xpos, xpos],
        [top, ypos],
        transform=ax.transAxes,
        color='darkslategrey',
    )
    separator.set_clip_on(False)
    ax.add_line(separator)
def label_len(my_index, level):
    """Run-length encode the labels of one index level.

    Returns a list of ``(label, count)`` pairs, one per run of consecutive
    equal values in ``my_index.get_level_values(level)``, in order.
    A label that reappears after a different one starts a new run.
    """
    runs = []
    for label, members in groupby(my_index.get_level_values(level)):
        runs.append((label, len(list(members))))
    return runs
def label_group_bar_table(ax, df):
    """Draw hierarchical x-axis labels with vertical separators under ``ax``.

    Index levels of ``df`` are handled from the innermost level outwards.
    The first row of labels sits at y = -0.1 (axes coordinates), and each
    further level is placed 0.1 lower. Every index row gets an equal share
    of the axes width. Each label is centred over its run of rows and
    rotated 90 degrees. A vertical separator (``add_line``) marks the left
    edge of every run plus the right edge of the last run.
    """
    scale = 1. / df.index.size
    row_y = -.1
    for level in reversed(range(df.index.nlevels)):
        start = 0
        for label, span in label_len(df.index, level):
            centre = (start + .5 * span) * scale
            ax.text(centre, row_y, label, ha='center',
                    transform=ax.transAxes, rotation=90)
            # Separator at the left edge of this run.
            add_line(ax, start * scale, row_y)
            start += span
        # Closing separator at the right edge of the final run.
        add_line(ax, start * scale, row_y)
        row_y -= .1
# Aggregate the data by the three index levels that form the label hierarchy.
# NOTE(review): test_table() is not defined in this snippet; presumably it
# comes from the linked question/answer. Confirm before running.
df = test_table().groupby(['Room', 'Shelf', 'Staple']).sum()

fig = plt.figure()
ax = fig.add_subplot(111)
df.plot(kind='bar', stacked=False, ax=ax)

# Blank out the default tick labels and the x-axis label; the hierarchical
# labels drawn by label_group_bar_table replace them.
labels = [''] * len(ax.get_xticklabels())
ax.set_xticklabels(labels)
ax.set_xlabel('')

label_group_bar_table(ax, df)
# Leave 0.1 of the figure height per index level for the label rows.
fig.subplots_adjust(bottom=.1 * df.index.nlevels)
plt.show()
答案 0（得分：0）
我修改了 @Stein 的代码，改为绘制水平线。代码不够简洁，也不够通用，但可以正常工作。
def add_lines(line_list, ax, *, gap=.01, y_offset=.6, color='slategrey'):
    """Draw horizontal connectors between consecutive group boundaries.

    Parameters
    ----------
    line_list : sequence of (x, y) pairs
        Boundary points in axes coordinates. One segment is drawn between
        each consecutive pair of points. With fewer than two points nothing
        is drawn.
    ax : matplotlib Axes
        Axes to draw on. Segments use ``ax.transAxes``.
    gap : float, keyword-only
        Amount trimmed from each end of every segment, so adjacent groups'
        lines do not touch. Defaults to 0.01.
    y_offset : float, keyword-only
        How far below the given y values the segments are drawn. Defaults to
        0.6, the same offset ``label_group_bar_table`` subtracts when placing
        group-label text.
    color : keyword-only
        Line colour. Defaults to ``'slategrey'``.
    """
    for first, second in zip(line_list, line_list[1:]):
        line = plt.Line2D(
            [first[0] + gap, second[0] - gap],
            [first[1] - y_offset, second[1] - y_offset],
            transform=ax.transAxes,
            color=color,
        )
        # Segments lie below the axes area, so clipping must be disabled.
        line.set_clip_on(False)
        ax.add_line(line)
def label_len(my_index, level):
    """Return ``(label, count)`` runs for one level of ``my_index``.

    Consecutive equal values of ``my_index.get_level_values(level)`` are
    collapsed into a single pair holding the value and the run length.
    The pairs are returned in order of appearance.
    """
    values = my_index.get_level_values(level)
    return [(key, len(list(run))) for key, run in groupby(values)]
def label_group_bar_table(ax, df):
    """Draw hierarchical x-axis labels with horizontal group lines under ``ax``.

    Index levels of ``df`` are processed from the innermost (leaf) level out
    to level 0. The first row sits at y = -0.03 (axes coordinates) and each
    further level is 0.2 lower. Every index row gets an equal share of the
    axes width, and each label is centred over its run of rows.

    * Leaf level: text rotated 90 degrees at the row's y.
    * Other levels: text rotated 180 degrees, placed 0.6 below the row's y.
    * Level 0 only: horizontal connectors are drawn between consecutive
      group boundaries via ``add_lines``.
    """
    row_y = -.03
    scale = 1. / df.index.size
    leaf_level = df.index.nlevels - 1
    for level in range(leaf_level, -1, -1):
        is_leaf = level == leaf_level
        # Group-level text sits 0.6 below its row, the same offset add_lines
        # applies to the connector segments.
        text_y = row_y if is_leaf else row_y - .6
        # NOTE(review): rotation=180 renders group labels upside down;
        # rotation=0 would give upright horizontal text. Confirm intent.
        angle = 90 if is_leaf else 180
        start = 0
        boundaries = []
        for label, span in label_len(df.index, level):
            centre = (start + .5 * span) * scale
            ax.text(centre, text_y, label, ha='center',
                    transform=ax.transAxes, rotation=angle, fontsize=1.75)
            boundaries.append((start * scale, row_y))
            start += span
        if not is_leaf:
            # Close the final group with its right-hand boundary.
            boundaries.append((start * scale, row_y))
            if level == 0:
                add_lines(boundaries, ax)
        row_y -= .2