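How do I manually assign colors to the levels of a categorical variable?

Time: 2019-11-25 02:34:26

Tags: python matplotlib seaborn

I am creating two plots with the code below. The JobDomain column holds the category values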

  • Cat1
  • Cat2
  • Cat3

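The code below produces two plots in which these categories get different colors. I need each of the three categories to have the same color in both plots.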

colors = ["#F28E2B", "#4E79A7","#79706E"]

edu = (df.groupby(['JobDomain'])['sal']
                         .value_counts(normalize=True)
                         .rename('Percentage')
                         .mul(100)
                         .reset_index()
                         .sort_values('sal'))

coding = (df.groupby(['JobDomain'])['sal2']
                         .value_counts(normalize=True)
                         .rename('Percentage')
                         .mul(100)
                         .reset_index()
                         .sort_values('sal2'))

fig, axs = plt.subplots(ncols=2,figsize=(20, 6),sharey=True)

plt.subplots_adjust(wspace=0.4)

p=sns.barplot(x="sal",y="Percentage",hue="JobDomain",data=edu,
              ax=axs[0],palette=sns.color_palette(colors))
q=sns.barplot(x="sal2",y="Percentage",hue="JobDomain",data=coding,
              ax=axs[1],palette=sns.color_palette(colors))

1 answer:

Answer 0 (score: 2)

Create a dictionary that maps each category to a color and pass it to palette directly, without calling sns.color_palette. An example:

import seaborn as sns
from pandas import DataFrame
from matplotlib import pyplot as plt

df = DataFrame({'JobDomain': ['Cat1', 'Cat2', 'Cat3', 'Cat1', 'Cat3'],
                'sal':       [  110,     90,    100,    200,    130],
                'sal2':      [  100,    280,    320,    240,    440]
                })

colors = {'Cat1': "#F28E2B", 'Cat2': "#4E79A7", 'Cat3': "#79706E"}

edu = (df.groupby(['JobDomain'])['sal']
                         .value_counts(normalize=True)
                         .rename('Percentage')
                         .mul(100)
                         .reset_index()
                         .sort_values('sal'))
coding = (df.groupby(['JobDomain'])['sal2']
                         .value_counts(normalize=True)
                         .rename('Percentage')
                         .mul(100)
                         .reset_index()
                         .sort_values('sal2'))

fig, axs = plt.subplots(ncols=2,figsize=(20, 6),sharey=True)
plt.subplots_adjust(wspace=0.4)
p = sns.barplot(x="sal",y="Percentage",hue="JobDomain",data=edu,ax=axs[0],palette=colors)
q = sns.barplot(x="sal2",y="Percentage",hue="JobDomain",data=coding,ax=axs[1],palette=colors)

h, l = p.get_legend_handles_labels()
l, h = zip(*sorted(zip(l, h)))
p.legend(h, l, title="Job Domain")
q.legend(h, l, title="Job Domain")

plt.show()
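
When there are many categories, or they are not known in advance, the dictionary can be built from the data instead of being typed by hand. The following is a minimal sketch, assuming a hypothetical helper named build_palette (it is not part of seaborn) and assuming that sorting the category names gives the order you want:

```python
def build_palette(categories, colors):
    """Map each distinct category to one color, in sorted category order.

    build_palette is a hypothetical helper for illustration only.
    """
    distinct = sorted(set(categories))
    if len(distinct) > len(colors):
        raise ValueError("not enough colors for the number of categories")
    # zip stops at the shorter sequence, so surplus colors are ignored
    return dict(zip(distinct, colors))


colors = ["#F28E2B", "#4E79A7", "#79706E"]
job_domains = ["Cat1", "Cat2", "Cat3", "Cat1", "Cat3"]  # e.g. df["JobDomain"]

palette = build_palette(job_domains, colors)
print(palette)
```

The resulting dictionary can then be passed as palette=palette to both sns.barplot calls, exactly like the hand-written colors dictionary in the example.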

PS: To sort the legend entries again, insert the following before plt.show() (the example above already includes these lines):

h, l = p.get_legend_handles_labels()
l, h = zip(*sorted(zip(l, h)))
p.legend(h, l, title="Job Domain")
q.legend(h, l, title="Job Domain")
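
The zip/sorted idiom used here can be tried without drawing anything. In this sketch plain strings stand in for the legend handles (an assumption for illustration; the real handles are matplotlib artists), and a key function is added so that only the labels are compared during sorting:

```python
# Labels and handles as they might come back from get_legend_handles_labels(),
# in an arbitrary order. The handle strings are stand-ins, not real artists.
labels = ["Cat3", "Cat1", "Cat2"]
handles = ["handle3", "handle1", "handle2"]

# Pair each label with its handle, sort the pairs by label, then unzip.
pairs = sorted(zip(labels, handles), key=lambda pair: pair[0])
sorted_labels, sorted_handles = zip(*pairs)

print(sorted_labels)   # ('Cat1', 'Cat2', 'Cat3')
print(sorted_handles)  # ('handle1', 'handle2', 'handle3')
```

Each handle stays attached to its label, so the legend colors still match the category names after sorting.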

Example plot