Python线性回归诊断图与R类似

时间:2017-10-06 14:15:37

标签: python plot regression linear-regression

我试图在Python中获得线性回归的诊断图,我想知道是否有快速的方法来做到这一点。

在R中,您可以使用下面的代码片段,其中包含残差与拟合图,正常Q-Q图,比例位置,残差与杠杆图。

m1 <- lm(cost~ distance, data = df1)
summary(m1)
plot(m1)

在python中有快速的方法吗?

这是一篇很棒的博客文章,描述了如何使用Python代码获得与R给你的相同的图,但它需要相当多的代码(至少与R方法相比)。链接:https://underthecurve.github.io/jekyll/update/2016/07/01/one-regression-six-ways.html#Python

1 个答案:

答案 0 :(得分:0)

我更喜欢将所有内容存储在pandas中,并在可能的情况下使用DataFrame.plot()进行绘制:

from matplotlib import pyplot as plt
from pandas.core.frame import DataFrame
import scipy.stats as stats
import statsmodels.api as sm


def linear_regression(df: DataFrame) -> DataFrame:
    """Perform a univariate regression and store results in a new data frame.

    Args:
        df (DataFrame): orginal data set with x and y.

    Returns:
        DataFrame: another dataframe with raw data and results.
    """
    mod = sm.OLS(endog=df['y'], exog=df['x']).fit()
    influence = mod.get_influence()

    res = df.copy()
    res['resid'] = mod.resid
    res['fittedvalues'] = mod.fittedvalues
    res['resid_std'] = mod.resid_pearson
    res['leverage'] = influence.hat_matrix_diag
    return res


def plot_diagnosis(df: DataFrame):
    fig, axes = plt.subplots(nrows=2, ncols=2)
    plt.style.use('seaborn')

    # Residual against fitted values.
    df.plot.scatter(
        x='fittedvalues', y='resid', ax=axes[0, 0]
    )
    axes[0, 0].axhline(y=0, color='grey', linestyle='dashed')
    axes[0, 0].set_xlabel('Fitted Values')
    axes[0, 0].set_ylabel('Residuals')
    axes[0, 0].set_title('Residuals vs Fitted')

    # qqplot
    sm.qqplot(
        df['resid'], dist=stats.t, fit=True, line='45',
        ax=axes[0, 1], c='#4C72B0'
    )
    axes[0, 1].set_title('Normal Q-Q')

    # The scale-location plot.
    df.plot.scatter(
        x='fittedvalues', y='resid_std', ax=axes[1, 0]
    )
    axes[1, 0].axhline(y=0, color='grey', linestyle='dashed')
    axes[1, 0].set_xlabel('Fitted values')
    axes[1, 0].set_ylabel('Sqrt(|standardized residuals|)')
    axes[1, 0].set_title('Scale-Location')

    # Standardized residuals vs. leverage
    df.plot.scatter(
        x='leverage', y='resid_std', ax=axes[1, 1]
    )
    axes[1, 1].axhline(y=0, color='grey', linestyle='dashed')
    axes[1, 1].set_xlabel('Leverage')
    axes[1, 1].set_ylabel('Sqrt(|standardized residuals|)')
    axes[1, 1].set_title('Residuals vs Leverage')

    plt.tight_layout()
    plt.show()

仍然缺少许多功能,但它提供了一个良好的开端。我在这里Access standardized residuals, cook's values, hatvalues (leverage) etc. easily in Python?

了解了如何提取影响力统计信息

enter image description here

顺便说一下,有一个具有所有功能的软件包dynobo/lmdiag