Python/Matplotlib 子图 - 堆积条形图 - 为类别设置固定颜色

时间:2021-03-13 06:29:42

标签: python pandas matplotlib colors subplot

我想从 pandas df(使用 df.pivot_table)创建堆叠条形图的子图,并保持子图的类别颜色一致(即固定)。

问题在于,并非数据透视表中的每个索引值(示例 df 中的“域”)都具有相同数量的类别 - 因此 matplotlib 重新启动每个子图中的类别着色 - 导致相同的颜色用于两个不同的类别。

这是用于说明的虚拟代码:

df:

    main domain category  val
0   cat1      a    apple    1
1   cat1      a   orange    1
2   cat1      a  broccli    1
3   cat1      b    apple    1
4   cat1      b   orange    1
5   cat1      a     plum    1
6   cat1      c    apple    1
7   cat2      b   orange    1
8   cat2      b   orange    1
9   cat2      b    apple    1
10  cat2      b   orange    1
11  cat2      c     plum    1
12  cat2      c    apple    1
13  cat2      b   orange    1
14  cat2      b   orange    1

代码是:

import pandas as pd
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(15, 10))
sub_plot_grid = (4, 10)
sub_plot_col_size = 5
sub_plot_row_size = 3

ax1 = plt.subplot2grid(sub_plot_grid, (0, 0), colspan=sub_plot_col_size, rowspan=sub_plot_row_size)
ax2 = plt.subplot2grid(sub_plot_grid, (0, 5), colspan=sub_plot_col_size, rowspan=sub_plot_row_size, sharey=ax1, sharex=ax1)
ax2.tick_params(axis='y', which='both', length=0)
list_of_ax = [ax1, ax2]

sub_df_1 = df[df['main'] == 'cat1']
sub_df_2 = df[df['main'] == 'cat2']
list_of_df = [sub_df_1, sub_df_2]

for ax, df in zip(list_of_ax, list_of_df):
    df1 = df.pivot_table(index='domain', columns='category', values='val', aggfunc='sum', dropna=False,  margins=True).sort_values(by='All', ascending=False).drop('All').drop('All', axis=1)
    df1.drop(df1.loc[df1.sum(axis=1) == 0].index, inplace=True)
    df1.drop(columns=df1.columns[df1.sum() == 0], inplace=True)
    df1.plot(kind='barh', stacked=True, alpha=0.7, ax=ax)

plt.show()

剧情是:

enter image description here

问题在“orange”和“plum”类别中很明显:在第一个子图中-“orange”类别的颜色在第二个子图中为绿色,颜色为橙色。第一个子图中的“梅花”类别颜色为红色,第二个子图中的颜色为绿色。

我需要每个类别的颜色对于所有子图都保持相同。

我已经搜索了一段时间的解决方案并尝试了一些不同的方法,包括尝试手动传递颜色列表或使用颜色图,但是 matplotlib 重新启动每个子图的颜色的问题仍然存在。

任何帮助将不胜感激。

1 个答案:

答案 0 :(得分:0)

请参阅文档以了解 bar charts with pandascolor 参数。具体来说,关于传递 dict 的部分:

<块引用>

{column namecolor} 形式的字典,这样每一列都将是 相应地着色。例如,如果您的列名为 a 和 b,则传递 {'a': 'green', 'b': 'red'} 将使 a 列的条形变为绿色,而 b 列的条形变为红色。

所以如果你定义:

colors = {'apple':'red', 'orange':'orange', 'plum':'purple', 'broccli':'green'}

然后绘制:

df1.plot(kind='barh', stacked=True, alpha=0.7, ax=ax, color=colors)

您将获得以下信息:

enter image description here