使用 seaborn.PairGrid 在 Python 中绘制相关矩阵时,如何为对角线上的直方图添加标题

时间:2019-06-26 08:08:27

标签: python seaborn correlation

花了一些时间摸清 PairGrid 的用法之后,我已经快完成了。下面的代码基本生成了我想要的图,只是 histfunc 中还缺少一个小细节:我想为对角线上绘制的直方图加上标题。如何把 DataFrame 的列名传递给 histfunc?欢迎任何建议。

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import dcor
import random
from scipy.stats import linregress
from matplotlib import rc

# Global matplotlib font settings.
# 'family' must be a generic family ('serif', 'sans-serif', 'monospace', ...)
# or an installed font name. 'normal' is neither, so matplotlib warned and fell
# back to its default; 'sans-serif' names that default explicitly.
font = {'family' : 'sans-serif',
        'weight' : 'normal',
        'size'   : 4}
rc('font', **font)

def corrmat(data, dc=False):
    """Plot a correlation-matrix style PairGrid for the columns of ``data``.

    Upper triangle: scatter plot with a degree-5 polynomial fit (red).
    Diagonal: histogram of each column, titled with the column name when
    seaborn passes the column as a pandas Series.
    Lower triangle: text annotation with the association strength and its
    significance stars (axes hidden).

    Parameters
    ----------
    data : pandas.DataFrame
        Columns to plot against each other.
    dc : bool, optional
        If True, the lower triangle reports distance correlation (via
        ``dcor``); otherwise r2 from a linear regression. Defaults to False,
        which matches the previous behaviour.
    """

    def dist_corr(X, Y, pval=True, nruns=100):
        """Distance correlation, optionally with a p-value from dcor's
        resampling-based independence test (``nruns`` resamples)."""
        dcorr = dcor.distance_correlation(X, Y)
        if not pval:
            # The resampling test is the expensive part; skip it when unused.
            return dcorr
        pv = dcor.independence.distance_covariance_test(
            X, Y, exponent=1.0, num_resamples=nruns)[0]
        return (dcorr, pv)

    def linreg(X, Y, pval=True):
        """Squared Pearson correlation (r2) from a linear regression,
        optionally together with the regression p-value."""
        result = linregress(X, Y)  # fit once instead of twice
        r2 = result.rvalue ** 2
        if pval:
            return (r2, result.pvalue)
        return r2

    def scatterfunc(x, y, **kws):
        """Scatter plot of y vs. x with a degree-5 polynomial fit in red."""
        plt.scatter(x, y, linewidths=1, facecolor="k", s=10, alpha=0.5)
        model = np.poly1d(np.polyfit(x, y, 5))
        x_sorted = np.sort(x)  # sorted so the fitted curve is drawn left to right
        plt.plot(x_sorted, model(x_sorted), 'r-')

    def histfunc(x, **kws):
        """Histogram of one column, titled with the column name when known."""
        plt.hist(x, bins=30, color="black", ec="white")
        # The ``label`` keyword seaborn passes does not carry the column name.
        # When the column arrives as a pandas Series, its ``name`` is the
        # DataFrame column name; otherwise no title is set.
        name = getattr(x, "name", None)
        if name is not None:
            plt.title(str(name))

    def corrfunc(x, y, dc=False, **kws):
        """Annotate the current axes with the association between x and y.

        The text colour and size grow with the strength ``d``; stars encode
        the p-value. ``dc`` selects distance correlation (True) or r2 from a
        linear regression (False).
        """
        if dc:
            d, p = dist_corr(x, y)
            prefix = 'DC: '
        else:
            d, p = linreg(x, y)
            prefix = 'r2: '

        # Strength bands, all keyed on d. The final ``else`` guarantees a
        # colour for every value, including NaN (e.g. a constant column).
        if d >= 0.75:
            pclr, fontsize = 'Red', 30
        elif d >= 0.5:
            pclr, fontsize = 'Orange', 25
        elif d >= 0.25:
            pclr, fontsize = 'Blue', 20
        else:
            pclr, fontsize = 'Black', 16

        # Significance stars; p == 0.05 and NaN fall through to "n.sig".
        if p < 0.001:
            ptext = "***"
        elif p < 0.01:
            ptext = "**"
        elif p < 0.05:
            ptext = "*"
        else:
            ptext = "n.sig"

        ax = plt.gca()
        ax.annotate(''.join([prefix, str(np.round(d, 2)), '\n\n    ', ptext]),
                    xy=(0.3, 0.3),
                    xycoords=ax.transAxes,
                    color=pclr,
                    fontsize=fontsize)
        plt.axis('off')

    # sns.PairGrid creates and manages its own figure, so a preceding
    # plt.figure() call would only add an extra, empty figure.
    g = sns.PairGrid(data, diag_sharey=False)
    g.map_upper(scatterfunc)
    g.map_diag(histfunc)
    g.map_lower(corrfunc, dc=dc)
    plt.tight_layout()
    plt.show()


########

# Demo data: 1000 rows x 10 uniform random columns named "0".."9".
data = pd.DataFrame(np.random.random([1000, 10]), columns=[str(i) for i in range(10)])
# From the third column on, each column is, with probability 0.5, multiplied
# by a randomly chosen earlier column, so the plot shows real associations.
for i, col in enumerate(data):
    if i > 1 and np.random.random() > 0.5:
        # random.sample() no longer accepts a set (TypeError on Python 3.11+);
        # randrange(i) picks an earlier column index uniformly and directly.
        data[col] = data[col] * data.iloc[:, random.randrange(i)]
corrmat(data)

上述代码生成的图如下:

(图:上述代码生成的相关矩阵图,原图缺失)

1 个答案:

答案 0(得分:0)

感谢 @ImportanceOfBeingErnest 在评论中的帮助。下面贴出更新后的脚本,供有需要的人参考。对角线直方图的标题是在所有子图绘制完成后,由 make_diag_titles 逐个设置到对角线子图上的。我还把散点图移到了下三角(lower),这样坐标轴标签就能显示出来。

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import dcor
import random
from scipy.stats import linregress
from matplotlib import rc

# Global matplotlib font settings.
# 'family' must be a generic family ('serif', 'sans-serif', 'monospace', ...)
# or an installed font name. 'normal' is neither, so matplotlib warned and fell
# back to its default; 'sans-serif' names that default explicitly.
font = {'family' : 'sans-serif',
        'weight' : 'normal',
        'size'   : 16}
rc('font', **font)

def corrmat(data, dc=False):
    """Plot a correlation-matrix style PairGrid for the columns of ``data``.

    Lower triangle: scatter plot with a degree-5 polynomial fit (red); this
    keeps the axis labels on the left and bottom edges visible.
    Diagonal: histogram of each plotted variable, titled with its name.
    Upper triangle: text annotation with the association strength and its
    significance stars (axes hidden).

    Parameters
    ----------
    data : pandas.DataFrame
        Columns to plot against each other.
    dc : bool, optional
        If True, the upper triangle reports distance correlation (via
        ``dcor``); otherwise r2 from a linear regression. Defaults to False,
        which matches the previous behaviour.
    """

    def dist_corr(X, Y, pval=True, nruns=100):
        """Distance correlation, optionally with a p-value from dcor's
        resampling-based independence test (``nruns`` resamples)."""
        dcorr = dcor.distance_correlation(X, Y)
        if not pval:
            # The resampling test is the expensive part; skip it when unused.
            return dcorr
        pv = dcor.independence.distance_covariance_test(
            X, Y, exponent=1.0, num_resamples=nruns)[0]
        return (dcorr, pv)

    def linreg(X, Y, pval=True):
        """Squared Pearson correlation (r2) from a linear regression,
        optionally together with the regression p-value."""
        result = linregress(X, Y)  # fit once instead of twice
        r2 = result.rvalue ** 2
        if pval:
            return (r2, result.pvalue)
        return r2

    def scatterfunc(x, y, **kws):
        """Scatter plot of y vs. x with a degree-5 polynomial fit in red."""
        plt.scatter(x, y, linewidths=1, facecolor="k", s=10, alpha=0.5)
        model = np.poly1d(np.polyfit(x, y, 5))
        x_sorted = np.sort(x)  # sorted so the fitted curve is drawn left to right
        plt.plot(x_sorted, model(x_sorted), 'r-')

    def histfunc(x, **kws):
        """Histogram of one column (titles are added by make_diag_titles)."""
        plt.hist(x, bins=30, color="black", ec="white")

    def corrfunc(x, y, dc=False, **kws):
        """Annotate the current axes with the association between x and y.

        The text colour and size grow with the strength ``d``; stars encode
        the p-value. ``dc`` selects distance correlation (True) or r2 from a
        linear regression (False).
        """
        if dc:
            d, p = dist_corr(x, y)
            prefix = 'DC: '
        else:
            d, p = linreg(x, y)
            prefix = 'r2: '

        # Strength bands, all keyed on d. The final ``else`` guarantees a
        # colour for every value, including NaN (e.g. a constant column).
        if d >= 0.75:
            pclr, fontsize = 'Red', 30
        elif d >= 0.5:
            pclr, fontsize = 'Orange', 25
        elif d >= 0.25:
            pclr, fontsize = 'Blue', 20
        else:
            pclr, fontsize = 'Black', 16

        # Significance stars; p == 0.05 and NaN fall through to "n.sig".
        if p < 0.001:
            ptext = "***"
        elif p < 0.01:
            ptext = "**"
        elif p < 0.05:
            ptext = "*"
        else:
            ptext = "n.sig"

        ax = plt.gca()
        ax.annotate(''.join([prefix, str(np.round(d, 2)), '\n\n    ', ptext]),
                    xy=(0.3, 0.3),
                    xycoords=ax.transAxes,
                    color=pclr,
                    fontsize=fontsize)
        plt.axis('off')

    def make_diag_titles(g, titles):
        """Title diagonal cell (i, i) of ``g.axes`` with ``titles[i]``; return g.

        Pairing stops at whichever of the diagonal or ``titles`` is shorter.
        """
        for ax, title in zip(np.diagonal(g.axes), titles):
            ax.title.set_text(title)
        return g

    # sns.PairGrid creates and manages its own figure, so a preceding
    # plt.figure() call would only add an extra, empty figure.
    g = sns.PairGrid(data, diag_sharey=False)
    g.map_lower(scatterfunc)
    g.map_diag(histfunc)
    g.map_upper(corrfunc, dc=dc)
    # Title with the variables PairGrid actually plotted. By default it uses
    # only numeric columns, so data.columns could misalign with the diagonal.
    g = make_diag_titles(g, g.x_vars)
    plt.tight_layout()
    plt.show()


########

# Demo data: 1000 rows x 10 uniform random columns named "0".."9".
data = pd.DataFrame(np.random.random([1000, 10]), columns=[str(i) for i in range(10)])
# From the third column on, each column is, with probability 0.5, multiplied
# by a randomly chosen earlier column, so the plot shows real associations.
for i, col in enumerate(data):
    if i > 1 and np.random.random() > 0.5:
        # random.sample() no longer accepts a set (TypeError on Python 3.11+);
        # randrange(i) picks an earlier column index uniformly and directly.
        data[col] = data[col] * data.iloc[:, random.randrange(i)]
corrmat(data)

(图:更新后脚本生成的相关矩阵图,原图缺失)