How to pass different scatter kwargs to the facets of a Seaborn lmplot

Time: 2016-11-27 13:25:20

Tags: matplotlib seaborn

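I am trying to map a third variable to the colour of the scatter points in a Seaborn lmplot. So total_bill is on x, tip is on y, and the point colour is a function of size.

This works when faceting is not enabled, but it fails when col is used, because the size of the colour array does not match the size of the data plotted in each facet.

Here is my code: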

    import matplotlib as mpl
    import seaborn as sns
    sns.set(color_codes=True)

    # load data
    data = sns.load_dataset("tips")

    # size of data
    print(len(data.index))

    ### we want to plot scatter point colour as function of variable 'size'

    # first, sort the data by 'size' so that high 'size' values are plotted
    # over the smaller sizes (so they are more visible)

    data = data.sort_values(by=['size'], ascending=True)

    scatter_kws = dict()
    cmap = mpl.cm.get_cmap(name='Blues')

    # normalise 'size' variable as float range needs to be
    # between 0 and 1 to map to a valid colour
    scatter_kws['c'] = data['size'] / data['size'].max()

    # map normalised values to colours
    scatter_kws['c'] = cmap(scatter_kws['c'].values)

    # colour array has same size as data
    print(len(scatter_kws['c']))

    # this works as intended
    g = sns.lmplot(data=data, x="total_bill", y="tip", scatter_kws=scatter_kws)

The above works well and produces the following (I am not allowed to include images, so here is the link):

lmplot with point colour as function of size

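However, when I add col='sex' to lmplot (see the code below), the colour array still has the size of the original dataset, which is larger than the size of the data plotted in each facet. For example, for col='male' there are 157 data points, so the first 157 values of the colour array are mapped to those points (and these are not even the correct values). See below: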

lmplot with point colour as function of size with col=sex

    g = sns.lmplot(data=data, x="total_bill", y="tip", col="sex", scatter_kws=scatter_kws)
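
The mismatch can be reproduced without any plotting. The following is a minimal sketch on a small made-up frame (a hypothetical stand-in for the tips data, not its real values); the variable names are chosen for illustration only. It compares the first entries of the full-length colour array with the entries that actually belong to the male rows.

```python
import pandas as pd

# Made-up stand-in for the tips data (assumption: not the real values)
df = pd.DataFrame({
    "sex":  ["Male", "Female", "Male", "Female", "Male", "Male"],
    "size": [4, 1, 2, 3, 6, 2],
}).sort_values(by="size")

# One normalised colour value per row of the full, sorted frame
colors = (df["size"] / df["size"].max()).to_numpy()

# Boolean mask of the rows that end up in the "Male" facet
is_male = (df["sex"] == "Male").to_numpy()
n_male = int(is_male.sum())

# Taking the first n_male entries, as described for the faceted plot
misaligned = colors[:n_male]
# Taking the entries that belong to the male rows
aligned = colors[is_male]

print(len(colors), n_male)
print(misaligned)
print(aligned)
```

The two printed arrays have the same length but different contents, which is why the facet ends up with colours that do not correspond to its points.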

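Ideally, I would like to pass an array of scatter_kws to lmplot so that each facet uses the correct colour array (which I would compute before calling lmplot). But that does not seem to be an option.

Are there any other ideas or workarounds that still let me use the functionality of Seaborn's lmplot (meaning, without having to re-create the FacetGrid functionality from lmplot)?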

1 Answer:

Answer 0 (score: 0)

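In principle, an lmplot with different cols seems to be just a wrapper around several regplots. So instead of one lmplot we can use two regplots, one for each sex.

We therefore need to split the original dataframe into male and female; the rest is fairly straightforward.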

    import matplotlib.pyplot as plt
    import seaborn as sns

    data = sns.load_dataset("tips")

    data = data.sort_values(by=['size'], ascending=True)
    # make a new dataframe for males and females
    male = data[data["sex"] == "Male"]
    female = data[data["sex"] == "Female"]

    # get normalized colors for all data
    colors = data['size'].values / float(data['size'].max())
    # get colors for males / females
    colors_male = colors[data["sex"].values == "Male"]
    colors_female = colors[data["sex"].values == "Female"]
    # colors are values in [0,1] range

    fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(9, 4))

    # create regplot for males, put it on the left axes
    # use colors_male to color the points with the Blues cmap
    # "color": None keeps regplot from also passing its default colour,
    # which recent matplotlib versions reject when "c" is given
    sns.regplot(data=male, x="total_bill", y="tip", ax=ax1,
                scatter_kws={"c": colors_male, "cmap": "Blues", "color": None})
    # same for females
    sns.regplot(data=female, x="total_bill", y="tip", ax=ax2,
                scatter_kws={"c": colors_female, "cmap": "Greens", "color": None})
    ax1.set_title("Males")
    ax2.set_title("Females")
    for ax in [ax1, ax2]:
        ax.set_xlim([0, 60])
        ax.set_ylim([0, 12])
    plt.tight_layout()
    plt.show()

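
If the faceting itself should stay automatic, another option is to build the grid with FacetGrid and let map_dataframe hand each facet its own subset of rows, so the colours are computed per facet. The following is only a sketch under assumptions: the helper name reg_with_colour is made up, the data is a synthetic stand-in for the tips dataset (so that the example runs offline), and the points are drawn with matplotlib's own scatter while regplot contributes only the regression line.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; remove to open a window
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Synthetic stand-in for the tips data (assumption: not the real values)
rng = np.random.default_rng(0)
n = 60
data = pd.DataFrame({
    "total_bill": rng.uniform(5, 50, n),
    "sex": rng.choice(["Male", "Female"], n),
    "size": rng.integers(1, 7, n),
})
data["tip"] = 1 + 0.1 * data["total_bill"] + rng.normal(0, 0.5, n)

# Shared colour limits so both facets use the same colour scale
vmin, vmax = data["size"].min(), data["size"].max()


def reg_with_colour(data, color=None, label=None, **kwargs):
    """Hypothetical helper: plot one facet's rows onto the current axes."""
    ax = plt.gca()
    # Sort so that larger 'size' values are drawn on top
    data = data.sort_values(by="size")
    # Colours come from this facet's own rows, so lengths always match
    ax.scatter(data["total_bill"], data["tip"], c=data["size"],
               cmap="Blues", vmin=vmin, vmax=vmax, edgecolor="grey")
    # Regression line and confidence band only; points are drawn above
    sns.regplot(data=data, x="total_bill", y="tip",
                scatter=False, color=color, ax=ax)


g = sns.FacetGrid(data, col="sex", height=4)
g.map_dataframe(reg_with_colour)
g.set_axis_labels("total_bill", "tip")
g.savefig("facets.png")
```

Because the helper only ever sees the rows of its own facet, the colour values cannot get out of step with the plotted points, and the shared vmin and vmax keep the colour scale comparable between the panels.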